From d2c03d2488a557bdd5c4d80559a17a225eb23ae2 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 28 Mar 2022 14:05:05 +0200 Subject: [PATCH 1/4] FIX allow copied parameters in __init__ for clone --- sklearn/base.py | 2 +- sklearn/tests/test_base.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 474380152028b..b3c5491b29a89 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -92,7 +92,7 @@ def clone(estimator, *, safe=True): for name in new_object_params: param1 = new_object_params[name] param2 = params_set[name] - if param1 is not param2: + if param1 is not param2 and param1 != param2: raise RuntimeError( "Cannot clone object %s, as the constructor " "either does not set or modifies parameter %s" % (estimator, name) diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 88a065fe79657..65e0bc8436de6 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -2,6 +2,7 @@ # License: BSD 3 clause import re +import copy import numpy as np import scipy.sparse as sp import pytest @@ -76,8 +77,8 @@ class ModifyInitParams(BaseEstimator): Doesn't fulfill a is a """ - def __init__(self, a=np.array([0])): - self.a = a.copy() + def __init__(self, a=np.array([0, 1], dtype=int)): + self.a = a.astype(float) class Buggy(BaseEstimator): @@ -210,6 +211,20 @@ def test_clone_class_rather_than_instance(): clone(MyEstimator) +def test_clone_eq_on_param(): + # test that copying given parameters in constructor is a valid operation + # for `clone` + # regression test for #22857 + class Est(BaseEstimator): + def __init__(self, param=None): + self.param = copy.deepcopy(param) + + est1 = Est(param={"key": "value"}) + # this should not raise + est2 = clone(est1) + assert est1.param == est2.param + + def test_repr(): # Smoke test the repr of the base estimator. 
my_estimator = MyEstimator() From 0f4fe01130b68778736b90c02327654932f25be5 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 1 Apr 2022 11:57:38 +0200 Subject: [PATCH 2/4] workaround for array comparison issue --- sklearn/base.py | 16 +++++++++++++--- sklearn/conftest.py | 9 +++++---- sklearn/utils/_pprint.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index b3c5491b29a89..ce1e22b667a9b 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -88,14 +88,24 @@ def clone(estimator, *, safe=True): new_object = klass(**new_object_params) params_set = new_object.get_params(deep=False) + init_param_error = ( + f"Cannot clone object {estimator}, as the constructor " + f"either does not set or modifies parameter {name}." + ) # quick sanity check of the parameters of the clone for name in new_object_params: param1 = new_object_params[name] param2 = params_set[name] - if param1 is not param2 and param1 != param2: + if param1 is param2: + continue + try: + if param1 != param2: + raise RuntimeError(init_param_error) + except ValueError as e: raise RuntimeError( - "Cannot clone object %s, as the constructor " - "either does not set or modifies parameter %s" % (estimator, name) + init_param_error + + " The following error occurred while comparing old and new" + f" parameters: {str(e)}" ) return new_object diff --git a/sklearn/conftest.py b/sklearn/conftest.py index b6bde190ad3fb..821b9c2044dbc 100644 --- a/sklearn/conftest.py +++ b/sklearn/conftest.py @@ -19,10 +19,7 @@ from sklearn.datasets import fetch_kddcup99 from sklearn.datasets import fetch_olivetti_faces from sklearn.datasets import fetch_rcv1 - - -# This plugin is necessary to define the random seed fixture -pytest_plugins = ("sklearn.tests.random_seed",) +from sklearn.tests import random_seed if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION): @@ -249,3 +246,7 @@ def pytest_configure(config): matplotlib.use("agg") except
ImportError: pass + + # Register global_random_seed plugin if it is not already registered + if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"): + config.pluginmanager.register(random_seed) diff --git a/sklearn/utils/_pprint.py b/sklearn/utils/_pprint.py index c96b1ce764c4a..7b5b5bbccdc18 100644 --- a/sklearn/utils/_pprint.py +++ b/sklearn/utils/_pprint.py @@ -98,8 +98,14 @@ def _changed_params(estimator): def has_changed(k, v): if k not in init_params: # happens if k is part of a **kwargs return True - if init_params[k] == inspect._empty: # k has no default value - return True + try: + if init_params[k] == inspect._empty: # k has no default value + return True + except ValueError: + # skip if "ValueError: The truth value of an array with more than + # one element is ambiguous. Use a.any() or a.all()" is raised + # during comparison. + pass # try to avoid calling repr on nested estimators if isinstance(v, BaseEstimator) and v.__class__ != init_params[k].__class__: return True From 11befeb8a06d5fcc3dc391680ba45f8220fe5bbe Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 1 Apr 2022 12:01:34 +0200 Subject: [PATCH 3/4] changelog --- doc/whats_new/v1.1.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 983f802f42388..f88971b426ff9 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -130,6 +130,12 @@ Changelog error message when setting invalid hyper-parameters with `set_params`. :pr:`21542` by :user:`Olivier Grisel <ogrisel>`. +:mod:`sklearn.base` +................... +- |Fix| :func:`base.clone` now clones estimators which store a copy of their + constructor arguments as long as they are not arrays. :pr:`22973` by `Adrin + Jalali`_. + :mod:`sklearn.calibration` ..........................
From 7256b04c37e8fc55722c0f72647024f561157c69 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 1 Apr 2022 14:30:16 +0200 Subject: [PATCH 4/4] move message to the right scope --- sklearn/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index ce1e22b667a9b..95e30ad87538c 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -88,14 +88,14 @@ def clone(estimator, *, safe=True): new_object = klass(**new_object_params) params_set = new_object.get_params(deep=False) - init_param_error = ( - f"Cannot clone object {estimator}, as the constructor " - f"either does not set or modifies parameter {name}." - ) # quick sanity check of the parameters of the clone for name in new_object_params: param1 = new_object_params[name] param2 = params_set[name] + init_param_error = ( + f"Cannot clone object {estimator}, as the constructor " + f"either does not set or modifies parameter {name}." + ) if param1 is param2: continue try: