diff --git a/sklearn/neighbors/_binary_tree.pxi.tp b/sklearn/neighbors/_binary_tree.pxi.tp
index dd77bcbdfb3d6..5cf7b0ad99990 100644
--- a/sklearn/neighbors/_binary_tree.pxi.tp
+++ b/sklearn/neighbors/_binary_tree.pxi.tp
@@ -1482,7 +1482,7 @@ cdef class BinaryTree{{name_suffix}}:
             raise ValueError("query data dimension must "
                              "match training data dimension")
         Xarr_np = X.reshape((-1, n_features))
-        cdef {{INPUT_DTYPE_t}}[:, ::1] Xarr = Xarr_np
+        cdef const {{INPUT_DTYPE_t}}[:, ::1] Xarr = Xarr_np
 
         log_density_arr = np.zeros(Xarr.shape[0], dtype={{INPUT_DTYPE}})
         cdef {{INPUT_DTYPE_t}}[::1] log_density = log_density_arr
diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py
index a1ed6b21a2219..20b3fdc600a8c 100644
--- a/sklearn/utils/tests/test_utils.py
+++ b/sklearn/utils/tests/test_utils.py
@@ -3,11 +3,13 @@
 import warnings
 from copy import copy
 from itertools import chain
+from unittest import SkipTest
 
 import numpy as np
 import pytest
 
 from sklearn import config_context
+from sklearn.externals._packaging.version import parse as parse_version
 from sklearn.utils import (
     _approximate_mode,
     _determine_key_type,
@@ -461,6 +463,12 @@ def test_safe_indexing_pandas_no_settingwithcopy_warning():
     # DataFrame -> ensure it doesn't raise a warning if modified
     pd = pytest.importorskip("pandas")
 
+    pd_version = parse_version(pd.__version__)
+    pd_base_version = parse_version(pd_version.base_version)
+
+    if pd_base_version >= parse_version("3"):
+        raise SkipTest("SettingWithCopyWarning has been removed in pandas 3.0.0.dev")
+
     X = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
     subset = _safe_indexing(X, [0, 1], axis=0)
     if hasattr(pd.errors, "SettingWithCopyWarning"):
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 7bbb7b753a1ba..a5c84ecf6411c 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -1084,6 +1084,25 @@ def is_sparse(dtype):
                 % (n_features, array.shape, ensure_min_features, context)
             )
 
+    # With an input pandas dataframe or series, we can usually make the
+    # resulting array writeable:
+    # - if copy=True, we have already made a copy so it is fine to make the
+    #   array writeable
+    # - if copy=False, the caller is telling us explicitly that we can do
+    #   in-place modifications
+    # See https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html#read-only-numpy-arrays
+    # for more details about pandas copy-on-write mechanism, that is enabled by
+    # default in pandas 3.0.0.dev.
+    if _is_pandas_df_or_series(array_orig) and hasattr(array, "flags"):
+        try:
+            array.flags.writeable = True
+        except ValueError:
+            # NumPy refuses to set WRITEABLE when the memory backing the
+            # array is itself read-only (e.g. a dataframe built on top of a
+            # read-only array or memmap). Keep the array read-only in that
+            # case instead of making input validation fail.
+            pass
+
     return array
 
 
@@ -2140,6 +2159,15 @@ def _check_method_params(X, params, indices=None):
     return method_params_validated
 
 
+def _is_pandas_df_or_series(X):
+    """Return True if the X is a pandas dataframe or series."""
+    try:
+        pd = sys.modules["pandas"]
+    except KeyError:
+        return False
+    return isinstance(X, (pd.DataFrame, pd.Series))
+
+
 def _is_pandas_df(X):
     """Return True if the X is a pandas dataframe."""
     try: