From 8dc719b7885023c9215b93db1061b43fc0fa3d6f Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 14 May 2024 15:03:38 +0200 Subject: [PATCH 01/13] common test + first applications --- sklearn/cluster/_affinity_propagation.py | 9 ++---- sklearn/decomposition/_factor_analysis.py | 2 +- sklearn/impute/_base.py | 1 + sklearn/tests/test_common.py | 25 +++++++++++++++ sklearn/utils/estimator_checks.py | 38 +++++++++++++++++++++++ sklearn/utils/validation.py | 33 +++++++++++++------- 6 files changed, 90 insertions(+), 18 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 735e30d3ea4b2..ef4d07c9e598d 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -504,13 +504,10 @@ def fit(self, X, y=None): Returns the instance itself. """ if self.affinity == "precomputed": - accept_sparse = False - else: - accept_sparse = "csr" - X = self._validate_data(X, accept_sparse=accept_sparse) - if self.affinity == "precomputed": - self.affinity_matrix_ = X.copy() if self.copy else X + X = self._validate_data(X, copy=self.copy, writeable=True) + self.affinity_matrix_ = X else: # self.affinity == "euclidean" + X = self._validate_data(X, accept_sparse="csr") self.affinity_matrix_ = -euclidean_distances(X, squared=True) if self.affinity_matrix_.shape[0] != self.affinity_matrix_.shape[1]: diff --git a/sklearn/decomposition/_factor_analysis.py b/sklearn/decomposition/_factor_analysis.py index af3498d534483..7ca2167aa6f72 100644 --- a/sklearn/decomposition/_factor_analysis.py +++ b/sklearn/decomposition/_factor_analysis.py @@ -219,7 +219,7 @@ def fit(self, X, y=None): self : object FactorAnalysis class instance. 
""" - X = self._validate_data(X, copy=self.copy, dtype=np.float64) + X = self._validate_data(X, copy=self.copy, dtype=np.float64, writeable=True) n_samples, n_features = X.shape n_components = self.n_components diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 04a4dffd10e68..2e303c0c9c818 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -334,6 +334,7 @@ def _validate_input(self, X, in_fit): reset=in_fit, accept_sparse="csc", dtype=dtype, + writeable=True if not in_fit else None, force_all_finite=force_all_finite, copy=self.copy, ) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9ff83953f4b0e..89048b449c050 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -79,6 +79,7 @@ check_get_feature_names_out_error, check_global_output_transform_pandas, check_global_set_output_transform_polars, + check_inplace_ensure_writeable, check_n_features_in_after_fitting, check_param_validation, check_set_output_transform, @@ -624,3 +625,27 @@ def test_set_output_transform_configured(estimator, check_func): _set_checking_parameters(estimator) with ignore_warnings(category=(FutureWarning)): check_func(estimator.__class__.__name__, estimator) + + +@pytest.mark.parametrize( + "estimator", _tested_estimators(), ids=_get_check_estimator_ids +) +def test_check_inplace_ensure_writeable(estimator): + if hasattr(estimator, "copy"): + estimator.set_params(copy=False) + elif hasattr(estimator, "copy_X"): + estimator.set_params(copy_X=False) + else: + raise SkipTest("Estimator doesn't require writeable input.") + + _set_checking_parameters(estimator) + + # The following estimators can work inplace only with certain settings + if estimator.__class__.__name__ == "HDBSCAN": + estimator.set_params(metric="precomputed") + estimator.set_params(algorithm="brute") + + if estimator.__class__.__name__ == "PCA": + estimator.set_params(svd_solver="full") + + 
check_inplace_ensure_writeable(estimator.__class__.__name__, estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 59d371bad57cd..f91546c7e038c 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4730,3 +4730,41 @@ def check_set_output_transform_polars(name, transformer_orig): def check_global_set_output_transform_polars(name, transformer_orig): _check_set_output_transform_polars_context(name, transformer_orig, "global") + + +@ignore_warnings(category=FutureWarning) +def check_inplace_ensure_writeable(name, estimator_orig): + """Check that estimators able to do inplace operations can work on read-only + input data even if a copy is not explicitly requested by the user. + """ + rng = np.random.RandomState(0) + + estimator = clone(estimator_orig) + set_random_state(estimator) + + n_samples = 100 + + X, _ = make_blobs(n_samples=n_samples, n_features=3, random_state=rng) + X = _enforce_estimator_tags_X(estimator, X) + + # These estimators can only work inplace with fortran ordered input + if name in ("Lasso", "ElasticNet", "MultiTaskElasticNet", "MultiTaskLasso"): + X = np.asfortranarray(X) + + # Add a missing value for imputers so that transform has to do something + if hasattr(estimator, "missing_values"): + X[0, 0] = np.nan + + if is_regressor(estimator): + y = rng.normal(size=n_samples) + else: + y = rng.randint(low=0, high=2, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + + # Make X read-only + X.setflags(write=False) + + estimator.fit(X, y) + + if hasattr(estimator, "transform"): + estimator.transform(X) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 5fac2ae6ae6c2..40835c3f50383 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -725,6 +725,7 @@ def check_array( accept_large_sparse=True, dtype="numeric", order=None, + writeable=None, copy=False, force_all_finite=True, ensure_2d=True, @@ -772,6 +773,13 @@ 
def check_array( the memory layout of the returned array is kept as close as possible to the original array. + writeable : True or None, default=None + Whether the returned array will be writeable. If True, the returned array is + guaranteed to be writeable, which may require a copy. If None, the writeability + of the input array is preserved. + + .. versionadded:: 1.6 + copy : bool, default=False Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion. @@ -1087,17 +1095,11 @@ def is_sparse(dtype): % (n_features, array.shape, ensure_min_features, context) ) - # With an input pandas dataframe or series, we know we can always make the - # resulting array writeable: - # - if copy=True, we have already made a copy so it is fine to make the - # array writeable - # - if copy=False, the caller is telling us explicitly that we can do - # in-place modifications - # See https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html#read-only-numpy-arrays - # for more details about pandas copy-on-write mechanism, that is enabled by - # default in pandas 3.0.0.dev. - if _is_pandas_df_or_series(array_orig) and hasattr(array, "flags"): - array.flags.writeable = True + if writeable and not array.flags.writeable: + try: + array.setflags(write=True) + except Exception: + array = array.copy() return array @@ -1132,6 +1134,7 @@ def check_X_y( accept_large_sparse=True, dtype="numeric", order=None, + writeable=None, copy=False, force_all_finite=True, ensure_2d=True, @@ -1183,6 +1186,13 @@ def check_X_y( Whether an array will be forced to be fortran or c-style. If `None`, then the input data's order is preserved when possible. + writeable : True or None, default=None + Whether the returned array will be writeable. If True, the returned array is + guaranteed to be writeable, which may require a copy. If None, the writeability + of the input array is preserved. + + .. 
versionadded:: 1.6 + copy : bool, default=False Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion. @@ -1269,6 +1279,7 @@ def check_X_y( accept_large_sparse=accept_large_sparse, dtype=dtype, order=order, + writeable=writeable, copy=copy, force_all_finite=force_all_finite, ensure_2d=ensure_2d, From 5bda2c3f9bc76f034fa5240d2f9121c80f462412 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 14 May 2024 16:34:14 +0200 Subject: [PATCH 02/13] include sparse --- sklearn/utils/validation.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 40835c3f50383..69dcf445156bf 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1095,11 +1095,19 @@ def is_sparse(dtype): % (n_features, array.shape, ensure_min_features, context) ) - if writeable and not array.flags.writeable: - try: - array.setflags(write=True) - except Exception: - array = array.copy() + if writeable: + if sp.issparse(array) and not array.data.flags.writeable: + try: + array.data.setflags(write=True) + array.indptr.setflags(write=True) + array.indices.setflags(write=True) + except Exception: + array = array.copy() + elif not sp.issparse(array) and not array.flags.writeable: + try: + array.setflags(write=True) + except Exception: + array = array.copy() return array From 98010b3171f00da52f02153f6018f76a8e9c158f Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 14 May 2024 17:13:33 +0200 Subject: [PATCH 03/13] simpler --- sklearn/utils/validation.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 69dcf445156bf..5839c261681a9 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1096,16 +1096,10 @@ def is_sparse(dtype): ) if writeable: - if sp.issparse(array) and not array.data.flags.writeable: + array_data = array.data if 
sp.issparse(array) else array + if not array_data.flags.writeable: try: - array.data.setflags(write=True) - array.indptr.setflags(write=True) - array.indices.setflags(write=True) - except Exception: - array = array.copy() - elif not sp.issparse(array) and not array.flags.writeable: - try: - array.setflags(write=True) + array_data.setflags(write=True) except Exception: array = array.copy() From 50d7cd098bc06cb657f3abf30cb8b1998c53da11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 23 May 2024 15:14:03 +0200 Subject: [PATCH 04/13] always copy when writeable + read-only but for 1 pandas exception --- sklearn/utils/tests/test_validation.py | 69 ++++++++++++++++++++++++++ sklearn/utils/validation.py | 20 ++++++-- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 92fff950e875e..f146d2f7527b3 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2,6 +2,7 @@ import numbers import re +import tempfile import warnings from itertools import product from operator import itemgetter @@ -2124,3 +2125,71 @@ def __init__(self): self.schema = ["a", "b"] assert not _is_polars_df(LooksLikePolars()) + + +def test_check_array_writeable_np(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on numpy arrays. 
+ """ + X = np.random.uniform(size=(10, 10)) + + out = check_array(X, copy=False, writeable=True) + # X is already writeable, no copy is needed + assert np.may_share_memory(out, X) + assert out.flags.writeable + + X.flags.writeable = False + + out = check_array(X, copy=False, writeable=True) + # X is not writeable, a copy is made + assert not np.may_share_memory(out, X) + assert out.flags.writeable + + +def test_check_array_writeable_mmap(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on a memory-map. + + A common situation is when a meta-estimator runs in parallel using multiprocessing + with joblib, which creates read-only memory-maps of large arrays. + """ + X = np.random.uniform(size=(10, 10)) + + with tempfile.NamedTemporaryFile() as f: + mmap = np.memmap(f.name, dtype="float64", mode="w+", shape=(10, 10)) + mmap[:] = X[:] + + out = check_array(mmap, copy=False, writeable=True) + # mmap is already writeable, no copy is needed + assert np.may_share_memory(out, mmap) + assert out.flags.writeable + + mmap = np.memmap(f.name, dtype="float64", mode="r", shape=(10, 10)) + + out = check_array(mmap, copy=False, writeable=True) + # mmap is read-only, a copy is made + assert not np.may_share_memory(out, mmap) + assert out.flags.writeable + + +def test_check_array_writeable_df(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on a dataframe. 
+ """ + pd = pytest.importorskip("pandas") + + X = np.random.uniform(size=(10, 10)) + df = pd.DataFrame(X, copy=False) + + out = check_array(df, copy=False, writeable=True) + # df is backed by a writeable array, no copy is needed + assert np.may_share_memory(out, df) + assert out.flags.writeable + + X.flags.writeable = False + df = pd.DataFrame(X, copy=False) + + out = check_array(df, copy=False, writeable=True) + # df is backed by a read-only array, a copy is made + assert not np.may_share_memory(out, df) + assert out.flags.writeable diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 01e1926778c08..e374636e1aa28 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1102,9 +1102,23 @@ def is_sparse(dtype): if writeable: array_data = array.data if sp.issparse(array) else array if not array_data.flags.writeable: - try: - array_data.setflags(write=True) - except Exception: + # This situation can only happen when copy=False, the array is read-only and + # a writeable output is requested. This is an ambiguous setting so we chose + # to always (except for one specific setting, see below) make a copy to + # ensure that the output is writeable, even if avoidable, to not overwrite + # the user's data by surprise. + + if _is_pandas_df_or_series(array_orig): + try: + # In pandas >= 3, np.asarray(df), called earlier in check_array, + # returns a read-only intermediate array. It can be made writeable + # safely without copy because if the original DataFrame was backed + # by a read-only array, trying to change the flag would raise an + # error, in which case we make a copy. 
+ array_data.flags.writeable = True + except ValueError: + array = array.copy() + else: array = array.copy() return array From a8b56a74f74ac025c00a86b867b4f1b075cc3f16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 23 May 2024 15:18:47 +0200 Subject: [PATCH 05/13] nit --- sklearn/tests/test_common.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 89048b449c050..2ddb82a9bc762 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -631,21 +631,22 @@ def test_set_output_transform_configured(estimator, check_func): "estimator", _tested_estimators(), ids=_get_check_estimator_ids ) def test_check_inplace_ensure_writeable(estimator): + name = estimator.__class__.__name__ + if hasattr(estimator, "copy"): estimator.set_params(copy=False) elif hasattr(estimator, "copy_X"): estimator.set_params(copy_X=False) else: - raise SkipTest("Estimator doesn't require writeable input.") + raise SkipTest(f"{name} doesn't require writeable input.") _set_checking_parameters(estimator) # The following estimators can work inplace only with certain settings - if estimator.__class__.__name__ == "HDBSCAN": - estimator.set_params(metric="precomputed") - estimator.set_params(algorithm="brute") + if name == "HDBSCAN": + estimator.set_params(metric="precomputed", algorithm="brute") - if estimator.__class__.__name__ == "PCA": + if name == "PCA": estimator.set_params(svd_solver="full") - check_inplace_ensure_writeable(estimator.__class__.__name__, estimator) + check_inplace_ensure_writeable(name, estimator) From 23cda9c4d51255f318bfe905aaf3bbb3ad701322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 24 May 2024 17:10:11 +0200 Subject: [PATCH 06/13] wip --- sklearn/decomposition/_pca.py | 1 + sklearn/impute/_knn.py | 1 + sklearn/linear_model/_base.py | 7 ++++++- 
sklearn/linear_model/_bayes.py | 7 ++++++- sklearn/linear_model/_ridge.py | 2 ++ sklearn/preprocessing/_data.py | 23 +++++++++++++++++++---- sklearn/tests/test_common.py | 3 +++ sklearn/utils/estimator_checks.py | 14 +++++++++++++- 8 files changed, 51 insertions(+), 7 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index cb0f2e7e02fb3..cece6343bd355 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -511,6 +511,7 @@ def _fit(self, X): X = self._validate_data( X, dtype=[xp.float64, xp.float32], + writeable=True, accept_sparse=("csr", "csc"), ensure_2d=True, copy=False, diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index 64f55693356d6..03d07967b7034 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -267,6 +267,7 @@ def transform(self, X): X, accept_sparse=False, dtype=FLOAT_DTYPES, + writeable=True, force_all_finite=force_all_finite, copy=self.copy, reset=False, diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index eac754f3f88b4..faaa4d6734967 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -607,7 +607,12 @@ def fit(self, X, y, sample_weight=None): accept_sparse = False if self.positive else ["csr", "csc", "coo"] X, y = self._validate_data( - X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True + X, + y, + accept_sparse=accept_sparse, + y_numeric=True, + multi_output=True, + writeable=True, ) has_sw = sample_weight is not None diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py index a572c82e6e158..cd678cf44599d 100644 --- a/sklearn/linear_model/_bayes.py +++ b/sklearn/linear_model/_bayes.py @@ -620,7 +620,12 @@ def fit(self, X, y): Fitted estimator. 
""" X, y = self._validate_data( - X, y, dtype=[np.float64, np.float32], y_numeric=True, ensure_min_samples=2 + X, + y, + dtype=[np.float64, np.float32], + writeable=True, + y_numeric=True, + ensure_min_samples=2, ) dtype = X.dtype diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index b336565cff1f6..e8094a72040b1 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1244,6 +1244,7 @@ def fit(self, X, y, sample_weight=None): y, accept_sparse=_accept_sparse, dtype=[xp.float64, xp.float32], + writeable=True, multi_output=True, y_numeric=True, ) @@ -1293,6 +1294,7 @@ def _prepare_data(self, X, y, sample_weight, solver): accept_sparse=accept_sparse, multi_output=True, y_numeric=False, + writeable=True, ) self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index 6dad8dc1c8c21..ac73b4003c5f8 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -535,6 +535,7 @@ def transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), + writeable=True, force_all_finite="allow-nan", reset=False, ) @@ -566,6 +567,7 @@ def inverse_transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), + writeable=True, force_all_finite="allow-nan", ) @@ -1046,6 +1048,7 @@ def transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, + writeable=True, force_all_finite="allow-nan", ) @@ -1087,6 +1090,7 @@ def inverse_transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, + writeable=True, force_all_finite="allow-nan", ) @@ -1291,6 +1295,7 @@ def transform(self, X): copy=self.copy, reset=False, dtype=_array_api.supported_float_dtypes(xp), + writeable=True, force_all_finite="allow-nan", ) @@ -1322,6 +1327,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, 
dtype=_array_api.supported_float_dtypes(xp), + writeable=True, force_all_finite="allow-nan", ) @@ -1654,6 +1660,7 @@ def transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, + writeable=True, reset=False, force_all_finite="allow-nan", ) @@ -1687,6 +1694,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, + writeable=True, force_all_finite="allow-nan", ) @@ -1928,6 +1936,7 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False): copy=copy, estimator="the normalize function", dtype=_array_api.supported_float_dtypes(xp), + writeable=True, ) if axis == 0: X = X.T @@ -2091,8 +2100,10 @@ def transform(self, X, copy=None): Transformed array. """ copy = copy if copy is not None else self.copy - X = self._validate_data(X, accept_sparse="csr", reset=False) - return normalize(X, norm=self.norm, axis=1, copy=copy) + X = self._validate_data( + X, accept_sparse="csr", writeable=True, copy=copy, reset=False + ) + return normalize(X, norm=self.norm, axis=1, copy=False) def _more_tags(self): return {"stateless": True, "array_api_support": True} @@ -2146,7 +2157,7 @@ def binarize(X, *, threshold=0.0, copy=True): array([[0., 1., 0.], [1., 0., 0.]]) """ - X = check_array(X, accept_sparse=["csr", "csc"], copy=copy) + X = check_array(X, accept_sparse=["csr", "csc"], writeable=True, copy=copy) if sparse.issparse(X): if threshold < 0: raise ValueError("Cannot binarize a sparse matrix with threshold < 0") @@ -2287,7 +2298,9 @@ def transform(self, X, copy=None): copy = copy if copy is not None else self.copy # TODO: This should be refactored because binarize also calls # check_array - X = self._validate_data(X, accept_sparse=["csr", "csc"], copy=copy, reset=False) + X = self._validate_data( + X, accept_sparse=["csr", "csc"], writeable=True, copy=copy, reset=False + ) return binarize(X, threshold=self.threshold, copy=False) def _more_tags(self): @@ -2852,6 +2865,7 @@ def _check_inputs(self, X, 
in_fit, accept_sparse_negative=False, copy=False): accept_sparse="csc", copy=copy, dtype=FLOAT_DTYPES, + writeable=True if not in_fit else None, force_all_finite="allow-nan", ) # we only accept positive sparse matrix when ignore_implicit_zeros is @@ -3490,6 +3504,7 @@ def _check_input(self, X, in_fit, check_positive=False, check_shape=False): X, ensure_2d=True, dtype=FLOAT_DTYPES, + writeable=True, copy=self.copy, force_all_finite="allow-nan", reset=in_fit, diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 2e73722afe5f7..fe3c1a8181173 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -642,4 +642,7 @@ def test_check_inplace_ensure_writeable(estimator): if name == "PCA": estimator.set_params(svd_solver="full") + if name == "KernelPCA": + estimator.set_params(kernel="precomputed") + check_inplace_ensure_writeable(name, estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index edd65dd1afcc9..85efe74b0cda5 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4733,6 +4733,9 @@ def check_global_set_output_transform_polars(name, transformer_orig): def check_inplace_ensure_writeable(name, estimator_orig): """Check that estimators able to do inplace operations can work on read-only input data even if a copy is not explicitly requested by the user. + + Make sure that a copy is made and consequently that the input array and its + writeability are not modified by the estimator. 
""" rng = np.random.RandomState(0) @@ -4758,10 +4761,19 @@ def check_inplace_ensure_writeable(name, estimator_orig): y = rng.randint(low=0, high=2, size=n_samples) y = _enforce_estimator_tags_y(estimator, y) + X_copy = X.copy() + # Make X read-only X.setflags(write=False) estimator.fit(X, y) if hasattr(estimator, "transform"): - estimator.transform(X) + Xt = estimator.transform(X) + + if hasattr(estimator, "inverse_transform"): + Xt.flags.writeable = False + estimator.inverse_transform(Xt) + + assert not X.flags.writeable + assert_allclose(X, X_copy) From 08cbaee27ac9c3ff10d8eeae596824a2e4c8606b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 24 May 2024 18:32:20 +0200 Subject: [PATCH 07/13] add writeable to estimators with inplace operations --- sklearn/cluster/_hdbscan/hdbscan.py | 5 ++++- sklearn/cross_decomposition/_pls.py | 18 ++++++++++++++---- sklearn/decomposition/_incremental_pca.py | 7 ++++++- sklearn/linear_model/_bayes.py | 4 +++- sklearn/linear_model/_coordinate_descent.py | 4 ++++ sklearn/linear_model/_least_angle.py | 8 +++++--- sklearn/utils/estimator_checks.py | 6 +----- sklearn/utils/validation.py | 4 ++-- 8 files changed, 39 insertions(+), 17 deletions(-) diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index 9933318313cc8..f62640b833bcf 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -770,6 +770,7 @@ def fit(self, X, y=None): X, accept_sparse=["csr", "lil"], dtype=np.float64, + writeable=True, ) else: # Only non-sparse, precomputed distance matrices are handled here @@ -777,7 +778,9 @@ def fit(self, X, y=None): # Perform data validation after removing infinite values (numpy.inf) # from the given distance matrix. 
- X = self._validate_data(X, force_all_finite=False, dtype=np.float64) + X = self._validate_data( + X, force_all_finite=False, dtype=np.float64, writeable=True + ) if np.isnan(X).any(): # TODO: Support np.nan in Cython implementation for precomputed # dense HDBSCAN diff --git a/sklearn/cross_decomposition/_pls.py b/sklearn/cross_decomposition/_pls.py index b6f7dd663724e..2786fee1cac79 100644 --- a/sklearn/cross_decomposition/_pls.py +++ b/sklearn/cross_decomposition/_pls.py @@ -263,10 +263,15 @@ def fit(self, X, y=None, Y=None): check_consistent_length(X, y) X = self._validate_data( - X, dtype=np.float64, copy=self.copy, ensure_min_samples=2 + X, dtype=np.float64, writeable=True, copy=self.copy, ensure_min_samples=2 ) y = check_array( - y, input_name="y", dtype=np.float64, copy=self.copy, ensure_2d=False + y, + input_name="y", + dtype=np.float64, + writeable=True, + copy=self.copy, + ensure_2d=False, ) if y.ndim == 1: self._predict_1d = True @@ -1056,10 +1061,15 @@ def fit(self, X, y=None, Y=None): y = _deprecate_Y_when_required(y, Y) check_consistent_length(X, y) X = self._validate_data( - X, dtype=np.float64, copy=self.copy, ensure_min_samples=2 + X, dtype=np.float64, writeable=True, copy=self.copy, ensure_min_samples=2 ) y = check_array( - y, input_name="y", dtype=np.float64, copy=self.copy, ensure_2d=False + y, + input_name="y", + dtype=np.float64, + writeable=True, + copy=self.copy, + ensure_2d=False, ) if y.ndim == 1: y = y.reshape(-1, 1) diff --git a/sklearn/decomposition/_incremental_pca.py b/sklearn/decomposition/_incremental_pca.py index 1089b2c54e086..2d2a802e026cd 100644 --- a/sklearn/decomposition/_incremental_pca.py +++ b/sklearn/decomposition/_incremental_pca.py @@ -229,6 +229,7 @@ def fit(self, X, y=None): accept_sparse=["csr", "csc", "lil"], copy=self.copy, dtype=[np.float64, np.float32], + writeable=True, ) n_samples, n_features = X.shape @@ -278,7 +279,11 @@ def partial_fit(self, X, y=None, check_input=True): "or use IncrementalPCA.fit to do so 
in batches." ) X = self._validate_data( - X, copy=self.copy, dtype=[np.float64, np.float32], reset=first_pass + X, + copy=self.copy, + dtype=[np.float64, np.float32], + writeable=True, + reset=first_pass, ) n_samples, n_features = X.shape if first_pass: diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py index cd678cf44599d..a2ec20eda78da 100644 --- a/sklearn/linear_model/_bayes.py +++ b/sklearn/linear_model/_bayes.py @@ -235,7 +235,9 @@ def fit(self, X, y, sample_weight=None): self : object Returns the instance itself. """ - X, y = self._validate_data(X, y, dtype=[np.float64, np.float32], y_numeric=True) + X, y = self._validate_data( + X, y, dtype=[np.float64, np.float32], writeable=True, y_numeric=True + ) dtype = X.dtype if sample_weight is not None: diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 6a62fa1e245e2..2c46ab11efee3 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -983,6 +983,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): accept_sparse="csc", order="F", dtype=[np.float64, np.float32], + writeable=True, accept_large_sparse=False, copy=X_copied, multi_output=True, @@ -1611,6 +1612,7 @@ def fit(self, X, y, sample_weight=None, **params): check_X_params = dict( accept_sparse="csc", dtype=[np.float64, np.float32], + writeable=True, copy=False, accept_large_sparse=False, ) @@ -1636,6 +1638,7 @@ def fit(self, X, y, sample_weight=None, **params): accept_sparse="csc", dtype=[np.float64, np.float32], order="F", + writeable=True, copy=copy_X, ) X, y = self._validate_data( @@ -2512,6 +2515,7 @@ def fit(self, X, y): check_X_params = dict( dtype=[np.float64, np.float32], order="F", + writeable=True, copy=self.copy_X and self.fit_intercept, ) check_y_params = dict(ensure_2d=False, order="F") diff --git a/sklearn/linear_model/_least_angle.py b/sklearn/linear_model/_least_angle.py index 
fae9f523847ad..d6d6f0517a3c8 100644 --- a/sklearn/linear_model/_least_angle.py +++ b/sklearn/linear_model/_least_angle.py @@ -1180,7 +1180,9 @@ def fit(self, X, y, Xy=None): self : object Returns an instance of self. """ - X, y = self._validate_data(X, y, y_numeric=True, multi_output=True) + X, y = self._validate_data( + X, y, writeable=True, y_numeric=True, multi_output=True + ) alpha = getattr(self, "alpha", 0.0) if hasattr(self, "n_nonzero_coefs"): @@ -1721,7 +1723,7 @@ def fit(self, X, y, **params): """ _raise_for_params(params, self, "fit") - X, y = self._validate_data(X, y, y_numeric=True) + X, y = self._validate_data(X, y, writeable=True, y_numeric=True) X = as_float_array(X, copy=self.copy_X) y = as_float_array(y, copy=self.copy_X) @@ -2238,7 +2240,7 @@ def fit(self, X, y, copy_X=None): """ if copy_X is None: copy_X = self.copy_X - X, y = self._validate_data(X, y, y_numeric=True) + X, y = self._validate_data(X, y, writeable=True, y_numeric=True) X, y, Xmean, ymean, Xstd = _preprocess_data( X, y, fit_intercept=self.fit_intercept, copy=copy_X diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 85efe74b0cda5..8ec1dc024ea4f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4769,11 +4769,7 @@ def check_inplace_ensure_writeable(name, estimator_orig): estimator.fit(X, y) if hasattr(estimator, "transform"): - Xt = estimator.transform(X) - - if hasattr(estimator, "inverse_transform"): - Xt.flags.writeable = False - estimator.inverse_transform(Xt) + estimator.transform(X) assert not X.flags.writeable assert_allclose(X, X_copy) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d2f35d6b4e7f9..d8c9c226a6cde 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1117,9 +1117,9 @@ def is_sparse(dtype): # error, in which case we make a copy. 
array_data.flags.writeable = True except ValueError: - array = array.copy() + array = array.copy(order="K") else: - array = array.copy() + array = array.copy(order="K") return array From 230d19cc0d6721c99329a9725c20293ba0b28734 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Tue, 4 Jun 2024 13:47:27 +0200 Subject: [PATCH 08/13] fix sparse and select arrays with flags --- sklearn/utils/validation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d8c9c226a6cde..0ce2f060615dc 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1101,7 +1101,8 @@ def is_sparse(dtype): if writeable: array_data = array.data if sp.issparse(array) else array - if not array_data.flags.writeable: + copy_params = {"order": "K"} if not sp.issparse(array) else {} + if hasattr(array_data, "flags") and not array_data.flags.writeable: # This situation can only happen when copy=False, the array is read-only and # a writeable output is requested. This is an ambiguous setting so we chose # to always (except for one specific setting, see below) make a copy to @@ -1117,9 +1118,9 @@ def is_sparse(dtype): # error, in which case we make a copy. 
array_data.flags.writeable = True except ValueError: - array = array.copy(order="K") + array = array.copy(**copy_params) else: - array = array.copy(order="K") + array = array.copy(**copy_params) return array From 0fe8eaf73dc80120cac9fe7fb99bc14a4cf992fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Tue, 4 Jun 2024 15:41:53 +0200 Subject: [PATCH 09/13] rework mmap test using existing testing tool --- sklearn/utils/tests/test_validation.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index f146d2f7527b3..4f3dda44a0cb1 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2,7 +2,6 @@ import numbers import re -import tempfile import warnings from itertools import product from operator import itemgetter @@ -47,6 +46,7 @@ assert_allclose_dense_sparse, assert_array_equal, assert_no_warnings, + create_memmap_backed_data, ignore_warnings, skip_if_array_api_compat_not_configured, ) @@ -2155,21 +2155,17 @@ def test_check_array_writeable_mmap(): """ X = np.random.uniform(size=(10, 10)) - with tempfile.NamedTemporaryFile() as f: - mmap = np.memmap(f.name, dtype="float64", mode="w+", shape=(10, 10)) - mmap[:] = X[:] - - out = check_array(mmap, copy=False, writeable=True) - # mmap is already writeable, no copy is needed - assert np.may_share_memory(out, mmap) - assert out.flags.writeable - - mmap = np.memmap(f.name, dtype="float64", mode="r", shape=(10, 10)) + mmap = create_memmap_backed_data(X, mmap_mode="w+") + out = check_array(mmap, copy=False, writeable=True) + # mmap is already writeable, no copy is needed + assert np.may_share_memory(out, mmap) + assert out.flags.writeable - out = check_array(mmap, copy=False, writeable=True) - # mmap is read-only, a copy is made - assert not np.may_share_memory(out, mmap) - assert out.flags.writeable + mmap = 
create_memmap_backed_data(X, mmap_mode="r") + out = check_array(mmap, copy=False, writeable=True) + # mmap is read-only, a copy is made + assert not np.may_share_memory(out, mmap) + assert out.flags.writeable def test_check_array_writeable_df(): From 2f796e2da039b82edaf27357e05c7abfde0e903a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Wed, 5 Jun 2024 15:51:51 +0200 Subject: [PATCH 10/13] add change log entry --- doc/whats_new/v1.5.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 60b8dadc97373..1ebcb68560533 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -20,6 +20,13 @@ Version 1.5.1 **TODO** +Changes impacting many modules +------------------------------ + +- |Fix| Fixed a regression in the validation of the input data of all estimators where + an unexpected error was raised when passing a DataFrame backed by a read-only buffer. + :pr:`29018` by :user:`Jérémie du Boisberranger <jeremiedbb>`. + Changelog --------- @@ -38,6 +45,14 @@ Changelog grids that have heterogeneous parameter values. :pr:`29078` by :user:`Loïc Estève <lesteve>`. +:mod:`sklearn.utils` +.................... + +- |API| :func:`utils.validation.check_array` has a new parameter, `writeable`, to + control the writeability of the output array. If set to True, the output array will + be guaranteed to be writeable and a copy will be made if the input array is read-only. + If left to None, no guarantee is made about the writeability of the output array. + :pr:`29018` by :user:`Jérémie du Boisberranger <jeremiedbb>`. .. 
_changes_1_5: From 9ed87a20b6cf809e59946f1783e595d39f3161e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 14 Jun 2024 14:51:42 +0200 Subject: [PATCH 11/13] rename force_writeable and make it a bool --- sklearn/cluster/_affinity_propagation.py | 2 +- sklearn/cluster/_hdbscan/hdbscan.py | 4 +-- sklearn/cross_decomposition/_pls.py | 16 ++++++--- sklearn/decomposition/_factor_analysis.py | 4 ++- sklearn/decomposition/_incremental_pca.py | 4 +-- sklearn/decomposition/_pca.py | 2 +- sklearn/impute/_base.py | 2 +- sklearn/impute/_knn.py | 2 +- sklearn/linear_model/_base.py | 2 +- sklearn/linear_model/_bayes.py | 4 +-- sklearn/linear_model/_coordinate_descent.py | 8 ++--- sklearn/linear_model/_least_angle.py | 6 ++-- sklearn/linear_model/_ridge.py | 4 +-- sklearn/preprocessing/_data.py | 32 ++++++++++-------- sklearn/utils/tests/test_validation.py | 12 +++---- sklearn/utils/validation.py | 36 ++++++++++----------- 16 files changed, 77 insertions(+), 63 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index ef4d07c9e598d..e739dd487cfa8 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -504,7 +504,7 @@ def fit(self, X, y=None): Returns the instance itself. 
""" if self.affinity == "precomputed": - X = self._validate_data(X, copy=self.copy, writeable=True) + X = self._validate_data(X, copy=self.copy, force_writeable=True) self.affinity_matrix_ = X else: # self.affinity == "euclidean" X = self._validate_data(X, accept_sparse="csr") diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index f62640b833bcf..d20e745309fca 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -770,7 +770,7 @@ def fit(self, X, y=None): X, accept_sparse=["csr", "lil"], dtype=np.float64, - writeable=True, + force_writeable=True, ) else: # Only non-sparse, precomputed distance matrices are handled here @@ -779,7 +779,7 @@ def fit(self, X, y=None): # Perform data validation after removing infinite values (numpy.inf) # from the given distance matrix. X = self._validate_data( - X, force_all_finite=False, dtype=np.float64, writeable=True + X, force_all_finite=False, dtype=np.float64, force_writeable=True ) if np.isnan(X).any(): # TODO: Support np.nan in Cython implementation for precomputed diff --git a/sklearn/cross_decomposition/_pls.py b/sklearn/cross_decomposition/_pls.py index 2786fee1cac79..463fb2dbcc300 100644 --- a/sklearn/cross_decomposition/_pls.py +++ b/sklearn/cross_decomposition/_pls.py @@ -263,13 +263,17 @@ def fit(self, X, y=None, Y=None): check_consistent_length(X, y) X = self._validate_data( - X, dtype=np.float64, writeable=True, copy=self.copy, ensure_min_samples=2 + X, + dtype=np.float64, + force_writeable=True, + copy=self.copy, + ensure_min_samples=2, ) y = check_array( y, input_name="y", dtype=np.float64, - writeable=True, + force_writeable=True, copy=self.copy, ensure_2d=False, ) @@ -1061,13 +1065,17 @@ def fit(self, X, y=None, Y=None): y = _deprecate_Y_when_required(y, Y) check_consistent_length(X, y) X = self._validate_data( - X, dtype=np.float64, writeable=True, copy=self.copy, ensure_min_samples=2 + X, + dtype=np.float64, + force_writeable=True, + 
copy=self.copy, + ensure_min_samples=2, ) y = check_array( y, input_name="y", dtype=np.float64, - writeable=True, + force_writeable=True, copy=self.copy, ensure_2d=False, ) diff --git a/sklearn/decomposition/_factor_analysis.py b/sklearn/decomposition/_factor_analysis.py index 7ca2167aa6f72..dacf6c386c9fb 100644 --- a/sklearn/decomposition/_factor_analysis.py +++ b/sklearn/decomposition/_factor_analysis.py @@ -219,7 +219,9 @@ def fit(self, X, y=None): self : object FactorAnalysis class instance. """ - X = self._validate_data(X, copy=self.copy, dtype=np.float64, writeable=True) + X = self._validate_data( + X, copy=self.copy, dtype=np.float64, force_writeable=True + ) n_samples, n_features = X.shape n_components = self.n_components diff --git a/sklearn/decomposition/_incremental_pca.py b/sklearn/decomposition/_incremental_pca.py index 2d2a802e026cd..fb43ab57387c1 100644 --- a/sklearn/decomposition/_incremental_pca.py +++ b/sklearn/decomposition/_incremental_pca.py @@ -229,7 +229,7 @@ def fit(self, X, y=None): accept_sparse=["csr", "csc", "lil"], copy=self.copy, dtype=[np.float64, np.float32], - writeable=True, + force_writeable=True, ) n_samples, n_features = X.shape @@ -282,7 +282,7 @@ def partial_fit(self, X, y=None, check_input=True): X, copy=self.copy, dtype=[np.float64, np.float32], - writeable=True, + force_writeable=True, reset=first_pass, ) n_samples, n_features = X.shape diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index cece6343bd355..69349e1530b39 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -511,7 +511,7 @@ def _fit(self, X): X = self._validate_data( X, dtype=[xp.float64, xp.float32], - writeable=True, + force_writeable=True, accept_sparse=("csr", "csc"), ensure_2d=True, copy=False, diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 2e303c0c9c818..8942e07b745a0 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -334,7 +334,7 @@ def _validate_input(self, 
X, in_fit): reset=in_fit, accept_sparse="csc", dtype=dtype, - writeable=True if not in_fit else None, + force_writeable=True if not in_fit else None, force_all_finite=force_all_finite, copy=self.copy, ) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index 44b7676ffc973..8bc8ca014887f 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -270,7 +270,7 @@ def transform(self, X): X, accept_sparse=False, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, force_all_finite=force_all_finite, copy=self.copy, reset=False, diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index faaa4d6734967..56beba93f6f98 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -612,7 +612,7 @@ def fit(self, X, y, sample_weight=None): accept_sparse=accept_sparse, y_numeric=True, multi_output=True, - writeable=True, + force_writeable=True, ) has_sw = sample_weight is not None diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py index a2ec20eda78da..c4356ac526a34 100644 --- a/sklearn/linear_model/_bayes.py +++ b/sklearn/linear_model/_bayes.py @@ -236,7 +236,7 @@ def fit(self, X, y, sample_weight=None): Returns the instance itself. 
""" X, y = self._validate_data( - X, y, dtype=[np.float64, np.float32], writeable=True, y_numeric=True + X, y, dtype=[np.float64, np.float32], force_writeable=True, y_numeric=True ) dtype = X.dtype @@ -625,7 +625,7 @@ def fit(self, X, y): X, y, dtype=[np.float64, np.float32], - writeable=True, + force_writeable=True, y_numeric=True, ensure_min_samples=2, ) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 2c46ab11efee3..9987304652b08 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -983,7 +983,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): accept_sparse="csc", order="F", dtype=[np.float64, np.float32], - writeable=True, + force_writeable=True, accept_large_sparse=False, copy=X_copied, multi_output=True, @@ -1612,7 +1612,7 @@ def fit(self, X, y, sample_weight=None, **params): check_X_params = dict( accept_sparse="csc", dtype=[np.float64, np.float32], - writeable=True, + force_writeable=True, copy=False, accept_large_sparse=False, ) @@ -1638,7 +1638,7 @@ def fit(self, X, y, sample_weight=None, **params): accept_sparse="csc", dtype=[np.float64, np.float32], order="F", - writeable=True, + force_writeable=True, copy=copy_X, ) X, y = self._validate_data( @@ -2515,7 +2515,7 @@ def fit(self, X, y): check_X_params = dict( dtype=[np.float64, np.float32], order="F", - writeable=True, + force_writeable=True, copy=self.copy_X and self.fit_intercept, ) check_y_params = dict(ensure_2d=False, order="F") diff --git a/sklearn/linear_model/_least_angle.py b/sklearn/linear_model/_least_angle.py index d6d6f0517a3c8..5f0715135d57b 100644 --- a/sklearn/linear_model/_least_angle.py +++ b/sklearn/linear_model/_least_angle.py @@ -1181,7 +1181,7 @@ def fit(self, X, y, Xy=None): Returns an instance of self. 
""" X, y = self._validate_data( - X, y, writeable=True, y_numeric=True, multi_output=True + X, y, force_writeable=True, y_numeric=True, multi_output=True ) alpha = getattr(self, "alpha", 0.0) @@ -1723,7 +1723,7 @@ def fit(self, X, y, **params): """ _raise_for_params(params, self, "fit") - X, y = self._validate_data(X, y, writeable=True, y_numeric=True) + X, y = self._validate_data(X, y, force_writeable=True, y_numeric=True) X = as_float_array(X, copy=self.copy_X) y = as_float_array(y, copy=self.copy_X) @@ -2240,7 +2240,7 @@ def fit(self, X, y, copy_X=None): """ if copy_X is None: copy_X = self.copy_X - X, y = self._validate_data(X, y, writeable=True, y_numeric=True) + X, y = self._validate_data(X, y, force_writeable=True, y_numeric=True) X, y, Xmean, ymean, Xstd = _preprocess_data( X, y, fit_intercept=self.fit_intercept, copy=copy_X diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index e8094a72040b1..7890875a76032 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1244,7 +1244,7 @@ def fit(self, X, y, sample_weight=None): y, accept_sparse=_accept_sparse, dtype=[xp.float64, xp.float32], - writeable=True, + force_writeable=True, multi_output=True, y_numeric=True, ) @@ -1294,7 +1294,7 @@ def _prepare_data(self, X, y, sample_weight, solver): accept_sparse=accept_sparse, multi_output=True, y_numeric=False, - writeable=True, + force_writeable=True, ) self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index d06b8786093ad..edf9dfdf80ea0 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -535,7 +535,7 @@ def transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), - writeable=True, + force_writeable=True, force_all_finite="allow-nan", reset=False, ) @@ -567,7 +567,7 @@ def inverse_transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), 
- writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1048,7 +1048,7 @@ def transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1090,7 +1090,7 @@ def inverse_transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1295,7 +1295,7 @@ def transform(self, X): copy=self.copy, reset=False, dtype=_array_api.supported_float_dtypes(xp), - writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1327,7 +1327,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), - writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1660,7 +1660,7 @@ def transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, reset=False, force_all_finite="allow-nan", ) @@ -1694,7 +1694,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1936,7 +1936,7 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False): copy=copy, estimator="the normalize function", dtype=_array_api.supported_float_dtypes(xp), - writeable=True, + force_writeable=True, ) if axis == 0: X = X.T @@ -2101,7 +2101,7 @@ def transform(self, X, copy=None): """ copy = copy if copy is not None else self.copy X = self._validate_data( - X, accept_sparse="csr", writeable=True, copy=copy, reset=False + X, accept_sparse="csr", force_writeable=True, copy=copy, reset=False ) return normalize(X, norm=self.norm, axis=1, copy=False) @@ -2157,7 +2157,7 @@ def binarize(X, *, threshold=0.0, copy=True): array([[0., 1., 0.], [1., 0., 0.]]) """ - X = check_array(X, accept_sparse=["csr", "csc"], 
writeable=True, copy=copy) + X = check_array(X, accept_sparse=["csr", "csc"], force_writeable=True, copy=copy) if sparse.issparse(X): if threshold < 0: raise ValueError("Cannot binarize a sparse matrix with threshold < 0") @@ -2299,7 +2299,11 @@ def transform(self, X, copy=None): # TODO: This should be refactored because binarize also calls # check_array X = self._validate_data( - X, accept_sparse=["csr", "csc"], writeable=True, copy=copy, reset=False + X, + accept_sparse=["csr", "csc"], + force_writeable=True, + copy=copy, + reset=False, ) return binarize(X, threshold=self.threshold, copy=False) @@ -2861,7 +2865,7 @@ def _check_inputs(self, X, in_fit, accept_sparse_negative=False, copy=False): accept_sparse="csc", copy=copy, dtype=FLOAT_DTYPES, - writeable=True if not in_fit else None, + force_writeable=True if not in_fit else None, force_all_finite="allow-nan", ) # we only accept positive sparse matrix when ignore_implicit_zeros is @@ -3500,7 +3504,7 @@ def _check_input(self, X, in_fit, check_positive=False, check_shape=False): X, ensure_2d=True, dtype=FLOAT_DTYPES, - writeable=True, + force_writeable=True, copy=self.copy, force_all_finite="allow-nan", reset=in_fit, diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 4f3dda44a0cb1..5bde51ae514d9 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2133,14 +2133,14 @@ def test_check_array_writeable_np(): """ X = np.random.uniform(size=(10, 10)) - out = check_array(X, copy=False, writeable=True) + out = check_array(X, copy=False, force_writeable=True) # X is already writeable, no copy is needed assert np.may_share_memory(out, X) assert out.flags.writeable X.flags.writeable = False - out = check_array(X, copy=False, writeable=True) + out = check_array(X, copy=False, force_writeable=True) # X is not writeable, a copy is made assert not np.may_share_memory(out, X) assert out.flags.writeable @@ -2156,13 +2156,13 @@ def 
test_check_array_writeable_mmap(): X = np.random.uniform(size=(10, 10)) mmap = create_memmap_backed_data(X, mmap_mode="w+") - out = check_array(mmap, copy=False, writeable=True) + out = check_array(mmap, copy=False, force_writeable=True) # mmap is already writeable, no copy is needed assert np.may_share_memory(out, mmap) assert out.flags.writeable mmap = create_memmap_backed_data(X, mmap_mode="r") - out = check_array(mmap, copy=False, writeable=True) + out = check_array(mmap, copy=False, force_writeable=True) # mmap is read-only, a copy is made assert not np.may_share_memory(out, mmap) assert out.flags.writeable @@ -2177,7 +2177,7 @@ def test_check_array_writeable_df(): X = np.random.uniform(size=(10, 10)) df = pd.DataFrame(X, copy=False) - out = check_array(df, copy=False, writeable=True) + out = check_array(df, copy=False, force_writeable=True) # df is backed by a writeable array, no copy is needed assert np.may_share_memory(out, df) assert out.flags.writeable @@ -2185,7 +2185,7 @@ def test_check_array_writeable_df(): X.flags.writeable = False df = pd.DataFrame(X, copy=False) - out = check_array(df, copy=False, writeable=True) + out = check_array(df, copy=False, force_writeable=True) # df is backed by a read-only array, a copy is made assert not np.may_share_memory(out, df) assert out.flags.writeable diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 0ce2f060615dc..1a0960d5026d7 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -722,8 +722,8 @@ def check_array( accept_large_sparse=True, dtype="numeric", order=None, - writeable=None, copy=False, + force_writeable=False, force_all_finite=True, ensure_2d=True, allow_nd=False, @@ -770,17 +770,17 @@ def check_array( the memory layout of the returned array is kept as close as possible to the original array. - writeable : True or None, default=None - Whether the returned array will be writeable. 
If True, the returned array is - guaranteed to be writeable, which may require a copy. If None, the writeability - of the input array is preserved. - - .. versionadded:: 1.6 - copy : bool, default=False Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion. + force_writeable : bool, default=False + Whether to force the output array to be writeable. If True, the returned array + is guaranteed to be writeable, which may require a copy. Otherwise the + writeability of the input array is preserved. + + .. versionadded:: 1.6 + force_all_finite : bool or 'allow-nan', default=True Whether to raise an error on np.inf, np.nan, pd.NA in array. The possibilities are: @@ -1099,7 +1099,7 @@ def is_sparse(dtype): % (n_features, array.shape, ensure_min_features, context) ) - if writeable: + if force_writeable: array_data = array.data if sp.issparse(array) else array copy_params = {"order": "K"} if not sp.issparse(array) else {} if hasattr(array_data, "flags") and not array_data.flags.writeable: @@ -1155,8 +1155,8 @@ def check_X_y( accept_large_sparse=True, dtype="numeric", order=None, - writeable=None, copy=False, + force_writeable=False, force_all_finite=True, ensure_2d=True, allow_nd=False, @@ -1207,17 +1207,17 @@ def check_X_y( Whether an array will be forced to be fortran or c-style. If `None`, then the input data's order is preserved when possible. - writeable : True or None, default=None - Whether the returned array will be writeable. If True, the returned array is - guaranteed to be writeable, which may require a copy. If None, the writeability - of the input array is preserved. - - .. versionadded:: 1.6 - copy : bool, default=False Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion. + force_writeable : bool, default=False + Whether to force the output array to be writeable. If True, the returned array + is guaranteed to be writeable, which may require a copy. 
Otherwise the + writeability of the input array is preserved. + + .. versionadded:: 1.6 + force_all_finite : bool or 'allow-nan', default=True Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter does not influence whether y can have np.inf, np.nan, pd.NA values. @@ -1300,8 +1300,8 @@ def check_X_y( accept_large_sparse=accept_large_sparse, dtype=dtype, order=order, - writeable=writeable, copy=copy, + force_writeable=force_writeable, force_all_finite=force_all_finite, ensure_2d=ensure_2d, allow_nd=allow_nd, From 5344d7b738eeed0f968c43d2cfcbb3deee673f06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 20 Jun 2024 11:59:03 +0200 Subject: [PATCH 12/13] fix what's new + add comments --- doc/whats_new/v1.5.rst | 6 +++--- sklearn/preprocessing/_data.py | 2 ++ sklearn/utils/validation.py | 8 ++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 430067fa1c2f6..9c1ff6aa6f53d 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -58,10 +58,10 @@ Changes impacting many modules :mod:`sklearn.utils` .................... -- |API| :func:`utils.validation.check_array` has a new parameter, `writeable`, to - control the writeability of the output array. If set to True, the output array will +- |API| :func:`utils.validation.check_array` has a new parameter, `force_writeable`, to + control the writeability of the output array. If set to `True`, the output array will be guaranteed to be writeable and a copy will be made if the input array is read-only. - If left to None, no guarantee is made about the writeability of the output array. + If set to `False`, no guarantee is made about the writeability of the output array. :pr:`29018` by :user:`Jérémie du Boisberranger <jeremiedbb>`. .. 
_changes_1_5: diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index fed818a6f2b5e..7e7d8a8dd3c17 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -2859,6 +2859,8 @@ def _check_inputs(self, X, in_fit, accept_sparse_negative=False, copy=False): accept_sparse="csc", copy=copy, dtype=FLOAT_DTYPES, + # only set force_writeable for the validation at transform time because + # it's the only place where QuantileTransformer performs inplace operations. force_writeable=True if not in_fit else None, force_all_finite="allow-nan", ) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index edb09fe78068b..228fbe76a25e1 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1094,9 +1094,13 @@ def is_sparse(dtype): ) if force_writeable: - array_data = array.data if sp.issparse(array) else array + # By default, array.copy() creates a C-ordered copy. We set order=K to + # preserve the order of the array. copy_params = {"order": "K"} if not sp.issparse(array) else {} - if hasattr(array_data, "flags") and not array_data.flags.writeable: + + array_data = array.data if sp.issparse(array) else array + flags = getattr(array_data, "flags", None) + if not getattr(flags, "writeable", True): # This situation can only happen when copy=False, the array is read-only and # a writeable output is requested. 
This is an ambiguous setting so we chose # to always (except for one specific setting, see below) make a copy to From 145e36de1bb4f3998b17dfb8c066957e57640a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 20 Jun 2024 12:02:00 +0200 Subject: [PATCH 13/13] cln merge --- doc/whats_new/v1.5.rst | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 9c1ff6aa6f53d..eae7195fea426 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -27,15 +27,12 @@ Changes impacting many modules an unexpected error was raised when passing a DataFrame backed by a read-only buffer. :pr:`29018` by :user:`Jérémie du Boisberranger <jeremiedbb>`. -Changelog ---------- - -Changes impacting many modules ------------------------------- - - |Fix| Fixed a regression causing a dead-lock at import time in some settings. :pr:`29235` by :user:`Jérémie du Boisberranger <jeremiedbb>`. +Changelog +--------- + :mod:`sklearn.metrics` ......................