diff --git a/doc/conf.py b/doc/conf.py
index 9ab1966b70e73..903ea36b4dd18 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -475,6 +475,7 @@ def add_js_css_files(app, pagename, templatename, context, doctree):
     "auto_examples/linear_model/plot_lasso_coordinate_descent_path.py": (
         "auto_examples/linear_model/plot_lasso_lasso_lars_elasticnet_path.py"
     ),
+    "auto_examples/linear_model/plot_ols_3d": ("auto_examples/linear_model/plot_ols"),
 }
 html_context["redirects"] = redirects
 for old_link in redirects:
diff --git a/examples/linear_model/plot_ols.py b/examples/linear_model/plot_ols.py
index 8aaa35ed8d899..aeb8e986459fc 100644
--- a/examples/linear_model/plot_ols.py
+++ b/examples/linear_model/plot_ols.py
@@ -1,63 +1,97 @@
 """
-=========================================================
-Linear Regression Example
-=========================================================
-The example below uses only the first feature of the `diabetes` dataset,
-in order to illustrate the data points within the two-dimensional plot.
-The straight line can be seen in the plot, showing how linear regression
-attempts to draw a straight line that will best minimize the
-residual sum of squares between the observed responses in the dataset,
-and the responses predicted by the linear approximation.
-
-The coefficients, residual sum of squares and the coefficient of
-determination are also calculated.
+==============================
+Ordinary Least Squares Example
+==============================
+This example shows how to use the ordinary least squares (OLS) model
+called :class:`~sklearn.linear_model.LinearRegression` in scikit-learn.
+
+For this purpose, we use a single feature from the diabetes dataset and try to
+predict the diabetes progression using this linear model. We therefore load the
+diabetes dataset and split it into training and test sets.
+
+Then, we fit the model on the training set, evaluate its performance on the
+test set, and finally visualize the results on the test set.
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-import matplotlib.pyplot as plt
-import numpy as np
-
-from sklearn import datasets, linear_model
+# %%
+# Data loading and preparation
+# ----------------------------
+#
+# Load the diabetes dataset. For simplicity, we only keep a single feature in the data.
+# Then, we split the data and target into training and test sets.
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+
+X, y = load_diabetes(return_X_y=True)
+X = X[:, [2]]  # Use only one feature
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False)
+
+# %%
+# Linear regression model
+# -----------------------
+#
+# We create a linear regression model and fit it on the training data. Note that by
+# default, an intercept is added to the model. We can control this behavior by setting
+# the `fit_intercept` parameter.
+from sklearn.linear_model import LinearRegression
+
+regressor = LinearRegression().fit(X_train, y_train)
+
+# %%
+# Model evaluation
+# ----------------
+#
+# We evaluate the model's performance on the test set using the mean squared error
+# and the coefficient of determination.
 from sklearn.metrics import mean_squared_error, r2_score
 
-# Load the diabetes dataset
-diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
-
-# Use only one feature
-diabetes_X = diabetes_X[:, np.newaxis, 2]
+y_pred = regressor.predict(X_test)
 
-# Split the data into training/testing sets
-diabetes_X_train = diabetes_X[:-20]
-diabetes_X_test = diabetes_X[-20:]
+print(f"Mean squared error: {mean_squared_error(y_test, y_pred):.2f}")
+print(f"Coefficient of determination: {r2_score(y_test, y_pred):.2f}")
 
-# Split the targets into training/testing sets
-diabetes_y_train = diabetes_y[:-20]
-diabetes_y_test = diabetes_y[-20:]
-
-# Create linear regression object
-regr = linear_model.LinearRegression()
-
-# Train the model using the training sets
-regr.fit(diabetes_X_train, diabetes_y_train)
+# %%
+# Plotting the results
+# --------------------
+#
+# Finally, we visualize the results on the train and test data.
+import matplotlib.pyplot as plt
 
-# Make predictions using the testing set
-diabetes_y_pred = regr.predict(diabetes_X_test)
+fig, ax = plt.subplots(ncols=2, figsize=(10, 5), sharex=True, sharey=True)
 
-# The coefficients
-print("Coefficients: \n", regr.coef_)
-# The mean squared error
-print("Mean squared error: %.2f" % mean_squared_error(diabetes_y_test, diabetes_y_pred))
-# The coefficient of determination: 1 is perfect prediction
-print("Coefficient of determination: %.2f" % r2_score(diabetes_y_test, diabetes_y_pred))
+ax[0].scatter(X_train, y_train, label="Train data points")
+ax[0].plot(
+    X_train,
+    regressor.predict(X_train),
+    linewidth=3,
+    color="tab:orange",
+    label="Model predictions",
+)
+ax[0].set(xlabel="Feature", ylabel="Target", title="Train set")
+ax[0].legend()
 
-# Plot outputs
-plt.scatter(diabetes_X_test, diabetes_y_test, color="black")
-plt.plot(diabetes_X_test, diabetes_y_pred, color="blue", linewidth=3)
+ax[1].scatter(X_test, y_test, label="Test data points")
+ax[1].plot(X_test, y_pred, linewidth=3, color="tab:orange", label="Model predictions")
+ax[1].set(xlabel="Feature", ylabel="Target", title="Test set")
+ax[1].legend()
 
-plt.xticks(())
-plt.yticks(())
+fig.suptitle("Linear Regression")
 
 plt.show()
+
+# %%
+# Conclusion
+# ----------
+#
+# The trained model corresponds to the estimator that minimizes the mean squared error
+# between the predicted and the true target values on the training data. We therefore
+# obtain an estimator of the conditional mean of the target given the data.
+#
+# Note that in higher dimensions, minimizing only the squared error might lead to
+# overfitting. Therefore, regularization techniques are commonly used to prevent this
+# issue, such as those implemented in :class:`~sklearn.linear_model.Ridge` or
+# :class:`~sklearn.linear_model.Lasso`.
diff --git a/examples/linear_model/plot_ols_3d.py b/examples/linear_model/plot_ols_3d.py
deleted file mode 100644
index cd848f659e8d8..0000000000000
--- a/examples/linear_model/plot_ols_3d.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""
-=========================================================
-Sparsity Example: Fitting only features 1 and 2
-=========================================================
-
-Features 1 and 2 of the diabetes-dataset are fitted and
-plotted below. It illustrates that although feature 2
-has a strong coefficient on the full model, it does not
-give us much regarding `y` when compared to just feature 1.
-"""
-
-# Authors: The scikit-learn developers
-# SPDX-License-Identifier: BSD-3-Clause
-
-# %%
-# First we load the diabetes dataset.
-
-import numpy as np
-
-from sklearn import datasets
-
-X, y = datasets.load_diabetes(return_X_y=True)
-indices = (0, 1)
-
-X_train = X[:-20, indices]
-X_test = X[-20:, indices]
-y_train = y[:-20]
-y_test = y[-20:]
-
-# %%
-# Next we fit a linear regression model.
-
-from sklearn import linear_model
-
-ols = linear_model.LinearRegression()
-_ = ols.fit(X_train, y_train)
-
-
-# %%
-# Finally we plot the figure from three different views.
-
-import matplotlib.pyplot as plt
-
-# unused but required import for doing 3d projections with matplotlib < 3.2
-import mpl_toolkits.mplot3d  # noqa: F401
-
-
-def plot_figs(fig_num, elev, azim, X_train, clf):
-    fig = plt.figure(fig_num, figsize=(4, 3))
-    plt.clf()
-    ax = fig.add_subplot(111, projection="3d", elev=elev, azim=azim)
-
-    ax.scatter(X_train[:, 0], X_train[:, 1], y_train, c="k", marker="+")
-    ax.plot_surface(
-        np.array([[-0.1, -0.1], [0.15, 0.15]]),
-        np.array([[-0.1, 0.15], [-0.1, 0.15]]),
-        clf.predict(
-            np.array([[-0.1, -0.1, 0.15, 0.15], [-0.1, 0.15, -0.1, 0.15]]).T
-        ).reshape((2, 2)),
-        alpha=0.5,
-    )
-    ax.set_xlabel("X_1")
-    ax.set_ylabel("X_2")
-    ax.set_zlabel("Y")
-    ax.xaxis.set_ticklabels([])
-    ax.yaxis.set_ticklabels([])
-    ax.zaxis.set_ticklabels([])
-
-
-# Generate the three different figures from different views
-elev = 43.5
-azim = -110
-plot_figs(1, elev, azim, X_train, ols)
-
-elev = -0.5
-azim = 0
-plot_figs(2, elev, azim, X_train, ols)
-
-elev = -0.5
-azim = 90
-plot_figs(3, elev, azim, X_train, ols)
-
-plt.show()
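The `fit_intercept` note in the new "Linear regression model" section can be illustrated with a short standalone sketch; this snippet is not part of the patch, and the synthetic data and coefficient values below are made up for illustration:

# Sketch (not part of the patch): effect of the `fit_intercept` parameter.
# With fit_intercept=False the fitted line is forced through the origin,
# so `intercept_` stays at 0.0 and only `coef_` is estimated.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = rng.uniform(-0.1, 0.15, size=(50, 1))  # synthetic feature on a diabetes-like scale
y = 150.0 + 900.0 * X.ravel() + rng.normal(scale=10.0, size=50)

with_intercept = LinearRegression().fit(X, y)  # default: fit_intercept=True
without_intercept = LinearRegression(fit_intercept=False).fit(X, y)

print(with_intercept.intercept_, with_intercept.coef_)  # roughly 150 and 900
print(without_intercept.intercept_, without_intercept.coef_)  # intercept_ is 0.0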
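Likewise, the :class:`~sklearn.linear_model.Ridge` and :class:`~sklearn.linear_model.Lasso` estimators named in the new conclusion share the `LinearRegression` fit/predict API while adding a penalty on the coefficients. A minimal sketch, assuming the same diabetes data but with all ten features; `alpha=1.0` is an untuned placeholder value:

# Sketch (not part of the patch): comparing OLS with its regularized variants.
# `alpha` controls the penalty strength; alpha=1.0 is only a default guess here.
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lasso, LinearRegression, Ridge
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)  # keep all ten features this time
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False)

for model in (LinearRegression(), Ridge(alpha=1.0), Lasso(alpha=1.0)):
    score = model.fit(X_train, y_train).score(X_test, y_test)  # R^2 on the test set
    print(f"{type(model).__name__}: R^2 = {score:.2f}")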