Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ Changelog
error message when setting invalid hyper-parameters with `set_params`.
:pr:`21542` by :user:`Olivier Grisel <ogrisel>`.

:mod:`sklearn.base`
...................
- |Fix| :func:`base.clone` now clones estimators which store a copy of their
constructor arguments as long as they are not arrays. :pr:`22973` by `Adrin
Jalali`_.

:mod:`sklearn.calibration`
..........................

Expand Down
16 changes: 13 additions & 3 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,20 @@ def clone(estimator, *, safe=True):
for name in new_object_params:
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
init_param_error = (
f"Cannot clone object {estimator}, as the constructor "
f"either does not set or modifies parameter {name}."
)
if param1 is param2:
continue
try:
if param1 != param2:
raise RuntimeError(init_param_error)
except ValueError as e:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
init_param_error
+ " The following error occurred while comparing old and new"
f" parameters: {str(e)}"
)
return new_object

Expand Down
9 changes: 5 additions & 4 deletions sklearn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
from sklearn.datasets import fetch_kddcup99
from sklearn.datasets import fetch_olivetti_faces
from sklearn.datasets import fetch_rcv1


# This plugin is necessary to define the random seed fixture
pytest_plugins = ("sklearn.tests.random_seed",)
from sklearn.tests import random_seed


if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
Expand Down Expand Up @@ -249,3 +246,7 @@ def pytest_configure(config):
matplotlib.use("agg")
except ImportError:
pass

# Register global_random_seed plugin if it is not already registered
if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a merge error? This change is already on main.

config.pluginmanager.register(random_seed)
19 changes: 17 additions & 2 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# License: BSD 3 clause

import re
import copy
import numpy as np
import scipy.sparse as sp
import pytest
Expand Down Expand Up @@ -76,8 +77,8 @@ class ModifyInitParams(BaseEstimator):
Doesn't fulfill a is a
"""

def __init__(self, a=np.array([0])):
self.a = a.copy()
Comment on lines -79 to -80
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the git history: 680ab51 and the original issue: #5540, the proposed feature in this PR was deprecated and removed in a older version of scikit-learn.

def __init__(self, a=np.array([0, 1], dtype=int)):
self.a = a.astype(float)


class Buggy(BaseEstimator):
Expand Down Expand Up @@ -210,6 +211,20 @@ def test_clone_class_rather_than_instance():
clone(MyEstimator)


def test_clone_eq_on_param():
# test that copying given parameters in constructor is a valid operation
# for `clone`
# regression test for #22857
class Est(BaseEstimator):
def __init__(self, param=None):
self.param = copy.deepcopy(param)

est1 = Est(param={"key": "value"})
# this should not raise
est2 = clone(est1)
assert est1.param == est2.param


def test_repr():
# Smoke test the repr of the base estimator.
my_estimator = MyEstimator()
Expand Down
10 changes: 8 additions & 2 deletions sklearn/utils/_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def _changed_params(estimator):
def has_changed(k, v):
if k not in init_params: # happens if k is part of a **kwargs
return True
if init_params[k] == inspect._empty: # k has no default value
return True
try:
if init_params[k] == inspect._empty: # k has no default value
return True
except ValueError:
# skip if "ValueError: The truth value of an array with more than
# one element is ambiguous. Use a.any() or a.all()" is raised
# during comparison.
pass
# try to avoid calling repr on nested estimators
if isinstance(v, BaseEstimator) and v.__class__ != init_params[k].__class__:
return True
Expand Down