DOC Release highlights for 1.8 #32809
Merged
Changes from all commits (55 commits):

- 5586be3 DOC Release highlights for 1.8 (lesteve)
- e42eb35 DOC linear models (lorentzenchr)
- a1d66ea Avoid FutureWarning (lesteve)
- 11902e1 minor tweak (lesteve)
- f7cacc6 Improve array API support section (lesteve)
- 8a52d30 Improve array API support section (lesteve)
- 7180b5e Apply suggestions from code review (lesteve)
- 1c9845f lint (lesteve)
- 2642813 Apply suggestions from code review (lesteve)
- d2aa37f lint (lesteve)
- 77fa4a2 Add doc highlights for ClassicalMDS (OmarManzoor)
- ce4abc4 Apply suggestions from code review (lesteve)
- a7289b0 Tweak free-threaded + lint (lesteve)
- 0817d1a Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- 1364eed add paragraph for html repr (jeremiedbb)
- ae920e0 lint (jeremiedbb)
- 69e92ad Remove code and add plot for ClassicalMDS (OmarManzoor)
- 3205220 Correct HTML repr highlights (lesteve)
- 8048b5b Add section for DecisionTreeRegressor with MAE (lesteve)
- 04b68b9 Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- 8ac0351 lint (lesteve)
- 46c3425 Link to array API example notebook (ogrisel)
- 17e904b Apply suggestions from code review about minor changes (lesteve)
- 86bfd63 Apply suggestions from code review array API (lesteve)
- 82a689a rst fix + tweak (lesteve)
- 68d1408 Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- 547fd2a tweak (lesteve)
- 01eda05 tweak (lesteve)
- 9684a44 Tweak array API support title (lesteve)
- be11549 Rewrite free-threaded section (lesteve)
- bccf6cd Add explicit n_jobs (lesteve)
- 4f2b83d tweak free-threading (lesteve)
- d819bfa Add uncalibrated to temperature scaling (lesteve)
- 346a490 Remove ConvergenceWarning on some random states (lesteve)
- 918b95d Improve grammar (betatim)
- 3a909cd Apply clarification (betatim)
- 035a1df Fix (betatim)
- 4a0f8b1 Apply suggestions from code review (lesteve)
- 508c186 lint (lesteve)
- fae1e7b More style (betatim)
- 0ce288d Tweak to make it clearer that TableVectorizer is from skrub (lesteve)
- fec5697 Remove section about linear model deprecation (lesteve)
- cabac1b rst seriously? (lesteve)
- ab68a91 small tweak [azure parallel] (lesteve)
- 8ebcc57 typo (lesteve)
- 8134033 [azure parallel] (lesteve)
- 2e45eab Apply suggestions from code review (lesteve)
- a0418de [azure parallel] (lesteve)
- 04364bc Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- d5599a7 Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- 53c57e9 Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- 9812187 [azure parallel] (lesteve)
- ab7cc83 Update examples/release_highlights/plot_release_highlights_1_8_0.py (lesteve)
- a1df531 [azure parallel] (lesteve)
- ed9513b Fix URL (betatim)
examples/release_highlights/plot_release_highlights_1_8_0.py (288 additions, 0 deletions)
# ruff: noqa: CPY001
"""
=======================================
Release Highlights for scikit-learn 1.8
=======================================

.. currentmodule:: sklearn

We are pleased to announce the release of scikit-learn 1.8! Many bug fixes
and improvements were added, as well as some key new features. Below we
detail the highlights of this release. **For an exhaustive list of
all the changes**, please refer to the :ref:`release notes <release_notes_1_8>`.

To install the latest version (with pip)::

    pip install --upgrade scikit-learn

or with conda::

    conda install -c conda-forge scikit-learn

"""

# %%
# Array API support (enables GPU computations)
# --------------------------------------------
# Thanks to the progressive adoption of the Python array API standard,
# scikit-learn can now work directly on PyTorch and CuPy input arrays. This
# means that non-CPU devices, such as GPUs, can be used to perform the
# computation in scikit-learn estimators and functions, improving performance
# and making integration with these libraries easier.
#
# In scikit-learn 1.8, several estimators and functions have been updated to
# support array API compatible inputs, for example PyTorch tensors and CuPy
# arrays.
#
# Array API support was added to the following estimators:
# :class:`preprocessing.StandardScaler`,
# :class:`preprocessing.PolynomialFeatures`, :class:`linear_model.RidgeCV`,
# :class:`linear_model.RidgeClassifierCV`, :class:`mixture.GaussianMixture` and
# :class:`calibration.CalibratedClassifierCV`.
#
# Array API support was also added to several metrics in the
# :mod:`sklearn.metrics` module, see :ref:`array_api_supported` for more
# details.
#
# Please refer to the :ref:`array API support <array_api>` page for
# instructions on using scikit-learn with array API compatible libraries such
# as PyTorch or CuPy. Note: array API support is experimental and must be
# explicitly enabled both in SciPy and scikit-learn.
#
# Here is an excerpt of using a feature engineering preprocessor on the CPU,
# followed by :class:`calibration.CalibratedClassifierCV` and
# :class:`linear_model.RidgeClassifierCV` together on a GPU with the help of
# PyTorch:
#
# .. code-block:: python
#
#     ridge_pipeline_gpu = make_pipeline(
#         # Ensure that all features (including categorical features) are
#         # preprocessed on the CPU and mapped to a numerical representation.
#         feature_preprocessor,
#         # Move the results to the GPU and perform computations there
#         FunctionTransformer(
#             lambda x: torch.tensor(x.to_numpy().astype(np.float32), device="cuda")
#         ),
#         CalibratedClassifierCV(
#             RidgeClassifierCV(alphas=alphas), method="temperature"
#         ),
#     )
#     with sklearn.config_context(array_api_dispatch=True):
#         cv_results = cross_validate(ridge_pipeline_gpu, features, target)
#
# See the `full notebook on Google Colab
# <https://colab.research.google.com/drive/1ztH8gUPv31hSjEeR_8pw20qShTwViGRx?usp=sharing>`_
# for more details. In this particular example, using the Colab GPU instead of
# a single CPU core leads to a 10x speedup, which is quite typical for such
# workloads.
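
# %%
# As a minimal sketch of what enabling array API dispatch looks like (assuming
# PyTorch and the `array-api-compat` package are installed), the same mechanism
# also works on CPU tensors, which is convenient for testing before moving to
# a GPU:
#
# .. code-block:: python
#
#     import torch
#     import sklearn
#     from sklearn.preprocessing import StandardScaler
#
#     X = torch.rand(100, 10)  # a CPU tensor; pass device="cuda" on a GPU machine
#     with sklearn.config_context(array_api_dispatch=True):
#         X_scaled = StandardScaler().fit_transform(X)
#     # The output stays a torch.Tensor on the same device as the input.
#     print(type(X_scaled), X_scaled.device)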

# %%
# Free-threaded CPython 3.14 support
# ----------------------------------
#
# scikit-learn supports free-threaded CPython; in particular, free-threaded
# wheels are available for all of our supported platforms on Python 3.14.
#
# We would be very interested in user feedback. Here are a few things you can
# try:
#
# - install free-threaded CPython 3.14, run your favourite
#   scikit-learn script and check that nothing breaks unexpectedly.
#   Note that CPython 3.14 (rather than 3.13) is strongly advised because a
#   number of free-threading bugs have been fixed since CPython 3.13.
# - if you use some estimators with an `n_jobs` parameter, try changing the
#   default backend to threading with `joblib.parallel_config` as in the
#   snippet below. This could potentially speed up your code because the
#   default joblib backend is process-based and incurs more overhead than
#   threads.
#
#   .. code-block:: python
#
#       grid_search = GridSearchCV(clf, param_grid=param_grid, n_jobs=4)
#       with joblib.parallel_config(backend="threading"):
#           grid_search.fit(X, y)
#
# - don't hesitate to report any issue or unexpected performance behaviour by
#   opening a `GitHub issue <https://github.com/scikit-learn/scikit-learn/issues/new/choose>`_!
#
# Free-threaded (also known as nogil) CPython is a version of CPython that aims
# to enable efficient multi-threaded use cases by removing the Global
# Interpreter Lock (GIL).
#
# For more details about free-threaded CPython see the `py-free-threading docs
# <https://py-free-threading.github.io>`_, in particular `how to install a
# free-threaded CPython <https://py-free-threading.github.io/installing-cpython/>`_
# and `Ecosystem compatibility tracking <https://py-free-threading.github.io/tracking/>`_.
#
# In scikit-learn, one hope with free-threaded Python is to leverage multi-core
# CPUs more efficiently by using thread workers instead of subprocess workers
# for parallel computation when passing `n_jobs>1` to functions or estimators.
# Efficiency gains are expected from removing the need for inter-process
# communication. Be aware that switching the default joblib backend and testing
# that everything works well with free-threaded Python is an ongoing long-term
# effort.
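
# %%
# As a quick sanity check (a sketch assuming a CPython 3.13 or later
# interpreter), you can verify at runtime whether you are actually running a
# free-threaded build with the GIL disabled:
#
# .. code-block:: python
#
#     import sys
#
#     # sys._is_gil_enabled() is available on CPython 3.13 and later; it
#     # returns False when the GIL is disabled in a free-threaded build.
#     if hasattr(sys, "_is_gil_enabled"):
#         print("GIL enabled:", sys._is_gil_enabled())
#     else:
#         print("Standard (GIL) build of CPython older than 3.13")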

# %%
# Temperature scaling in `CalibratedClassifierCV`
# -----------------------------------------------
# Probability calibration of classifiers with temperature scaling is available
# in :class:`calibration.CalibratedClassifierCV` by setting
# `method="temperature"`. This method is particularly well suited for
# multiclass problems because it provides (better) calibrated probabilities
# with a single free parameter. This is in contrast to all the other available
# calibration methods, which use a "One-vs-Rest" scheme that adds more
# parameters for each class.

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.naive_bayes import GaussianNB

X, y = make_classification(n_classes=3, n_informative=8, random_state=42)
clf = GaussianNB().fit(X, y)
sig = CalibratedClassifierCV(clf, method="sigmoid", ensemble=False).fit(X, y)
ts = CalibratedClassifierCV(clf, method="temperature", ensemble=False).fit(X, y)

# %%
# The following example shows that temperature scaling can produce better
# calibrated probabilities than sigmoid calibration in a multi-class
# classification problem with 3 classes.

import matplotlib.pyplot as plt

from sklearn.calibration import CalibrationDisplay

fig, axes = plt.subplots(
    figsize=(8, 4.5),
    ncols=3,
    sharey=True,
)
for i, c in enumerate(ts.classes_):
    CalibrationDisplay.from_predictions(
        y == c, clf.predict_proba(X)[:, i], name="Uncalibrated", ax=axes[i], marker="s"
    )
    CalibrationDisplay.from_predictions(
        y == c,
        ts.predict_proba(X)[:, i],
        name="Temperature scaling",
        ax=axes[i],
        marker="o",
    )
    CalibrationDisplay.from_predictions(
        y == c, sig.predict_proba(X)[:, i], name="Sigmoid", ax=axes[i], marker="v"
    )
    axes[i].set_title(f"Class {c}")
    axes[i].set_xlabel(None)
    axes[i].set_ylabel(None)
    axes[i].get_legend().remove()
fig.suptitle("Reliability Diagrams per Class")
fig.supxlabel("Mean Predicted Probability")
fig.supylabel("Fraction of Class")
fig.legend(*axes[0].get_legend_handles_labels(), loc=(0.72, 0.5))
plt.subplots_adjust(right=0.7)
_ = fig.show()

# %%
# Efficiency improvements in linear models
# ----------------------------------------
# The fit time has been massively reduced for squared error based estimators
# with L1 penalty: `ElasticNet`, `Lasso`, `MultiTaskElasticNet`,
# `MultiTaskLasso` and their CV variants. The fit time improvement is mainly
# achieved by **gap safe screening rules**. They enable the coordinate descent
# solver to set feature coefficients to zero early on and not look at them
# again. The stronger the L1 penalty, the earlier features can be excluded
# from further updates.

from time import time

from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNetCV

X, y = make_regression(n_features=10_000, random_state=0)
model = ElasticNetCV()
tic = time()
model.fit(X, y)
toc = time()
print(f"Fitting ElasticNetCV took {toc - tic:.3} seconds.")

# %%
# HTML representation of estimators
# ---------------------------------
# Hyperparameters in the dropdown table of the HTML representation now include
# links to the online documentation. Docstring descriptions are also shown as
# tooltips on hover.

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0, C=10))

# %%
# Expand the estimator diagram below by clicking on "LogisticRegression" and
# then on "Parameters".

clf
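
# %%
# If you want to embed the same diagram elsewhere, for example in a standalone
# HTML report, :func:`sklearn.utils.estimator_html_repr` returns it as an HTML
# string:

from sklearn.utils import estimator_html_repr

# The full string can be written to a .html file and opened in a browser.
html_snippet = estimator_html_repr(clf)
print(html_snippet[:80])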

# %%
# DecisionTreeRegressor with `criterion="absolute_error"`
# --------------------------------------------------------
# :class:`tree.DecisionTreeRegressor` with `criterion="absolute_error"`
# now runs much faster. It now has `O(n * log(n))` complexity, compared to
# `O(n**2)` previously, which allows it to scale to millions of data points.
#
# As an illustration, on a dataset with 100_000 samples and 1 feature, doing a
# single split takes on the order of 100 ms, compared to ~20 seconds before.

import time

from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=100_000, n_features=1)
tree = DecisionTreeRegressor(criterion="absolute_error", max_depth=1)

tic = time.time()
tree.fit(X, y)
elapsed = time.time() - tic
print(f"Fit took {elapsed:.2f} seconds")

# %%
# ClassicalMDS
# ------------
# Classical MDS, also known as "Principal Coordinates Analysis" (PCoA)
# or "Torgerson's scaling", is now available in the `sklearn.manifold`
# module. Classical MDS is closely related to PCA: instead of approximating
# distances, it approximates pairwise scalar products, which has an exact
# analytic solution in terms of an eigendecomposition.
#
# Let's illustrate this new addition by using it on an S-curve dataset to
# get a low-dimensional representation of the data.

import matplotlib.pyplot as plt
from matplotlib import ticker

from sklearn import datasets, manifold

n_samples = 1500
S_points, S_color = datasets.make_s_curve(n_samples, random_state=0)
md_classical = manifold.ClassicalMDS(n_components=2)
S_scaling = md_classical.fit_transform(S_points)

fig = plt.figure(figsize=(8, 4))
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
x, y, z = S_points.T
ax1.scatter(x, y, z, c=S_color, s=50, alpha=0.8)
ax1.set_title("Original S-curve samples", size=16)
ax1.view_init(azim=-60, elev=9)
for axis in (ax1.xaxis, ax1.yaxis, ax1.zaxis):
    axis.set_major_locator(ticker.MultipleLocator(1))

ax2 = fig.add_subplot(1, 2, 2)
x2, y2 = S_scaling.T
ax2.scatter(x2, y2, c=S_color, s=50, alpha=0.8)
ax2.set_title("Classical MDS", size=16)
for axis in (ax2.xaxis, ax2.yaxis):
    axis.set_major_formatter(ticker.NullFormatter())

plt.show()