Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MNT Param validation: Make it possible to mark a constraint as hidden #23558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sklearn/cluster/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ..utils import check_random_state
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils.validation import _is_arraylike_not_scalar
from ..utils._param_validation import Hidden
from ..utils._param_validation import Interval
from ..utils._param_validation import StrOptions
from ..utils._param_validation import validate_params
Expand Down Expand Up @@ -273,7 +274,8 @@ def _tolerance(X, tol):
"sample_weight": ["array-like", None],
"init": [StrOptions({"k-means++", "random"}), callable, "array-like"],
"n_init": [
StrOptions({"auto", "warn"}, internal={"warn"}),
StrOptions({"auto"}),
Hidden(StrOptions({"warn"})),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is StrOptions({"warn"}, hidden=True), but then we would need to make _InstancesOf "public" in order to write InstancesOf(Criterion, hidden=True). All the other constraints would also have to be made "public" so that they could be configured as hidden.

With that in mind, I'm okay with the PR as is.

Interval(Integral, 1, None, closed="left"),
],
"max_iter": [Interval(Integral, 1, None, closed="left")],
Expand Down Expand Up @@ -834,7 +836,8 @@ class _BaseKMeans(
"n_clusters": [Interval(Integral, 1, None, closed="left")],
"init": [StrOptions({"k-means++", "random"}), callable, "array-like"],
"n_init": [
StrOptions({"auto", "warn"}, internal={"warn"}),
StrOptions({"auto"}),
Hidden(StrOptions({"warn"})),
Interval(Integral, 1, None, closed="left"),
],
"max_iter": [Interval(Integral, 1, None, closed="left")],
Expand Down
4 changes: 2 additions & 2 deletions sklearn/linear_model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from joblib import Parallel
from numbers import Integral

from ..utils._param_validation import StrOptions
from ..utils._param_validation import StrOptions, Hidden
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin, MultiOutputMixin
from ..preprocessing._data import _is_constant_feature
from ..utils import check_array
Expand Down Expand Up @@ -640,7 +640,7 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):

_parameter_constraints = {
"fit_intercept": [bool],
"normalize": [StrOptions({"deprecated"}, internal={"deprecated"}), bool],
"normalize": [Hidden(StrOptions({"deprecated"})), bool],
"copy_X": [bool],
"n_jobs": [None, Integral],
"positive": [bool],
Expand Down
54 changes: 32 additions & 22 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
else:
# No constraint is satisfied, raise with an informative message.

# Ignore constraints that only contains internal options that we don't want
# to expose in the error message
constraints = [constraint for constraint in constraints if str(constraint)]
# Ignore constraints that we don't want to expose in the error message,
# i.e. options that are for internal purposes or not officially supported.
constraints = [
constraint for constraint in constraints if not constraint.hidden
]

if len(constraints) == 1:
constraints_str = f"{constraints[0]}"
Expand Down Expand Up @@ -98,6 +100,10 @@ def make_constraint(constraint):
return _InstancesOf(constraint)
if isinstance(constraint, (Interval, StrOptions)):
return constraint
if isinstance(constraint, Hidden):
constraint = make_constraint(constraint.constraint)
constraint.hidden = True
return constraint
raise ValueError(f"Unknown constraint type: {constraint}")


Expand Down Expand Up @@ -148,6 +154,9 @@ def wrapper(*args, **kwargs):
class _Constraint(ABC):
"""Base class for the constraint objects."""

def __init__(self):
self.hidden = False

@abstractmethod
def is_satisfied_by(self, val):
"""Whether or not a value satisfies the constraint.
Expand Down Expand Up @@ -178,6 +187,7 @@ class _InstancesOf(_Constraint):
"""

def __init__(self, type):
super().__init__()
self.type = type

def _type_name(self, t):
Expand Down Expand Up @@ -221,25 +231,15 @@ class StrOptions(_Constraint):
A subset of the `options` to mark as deprecated in the repr of the constraint.
"""

@validate_params(
{"options": [set], "deprecated": [set, None], "internal": [set, None]}
)
def __init__(self, options, deprecated=None, internal=None):
@validate_params({"options": [set], "deprecated": [set, None]})
def __init__(self, options, deprecated=None):
super().__init__()
self.options = options
self.deprecated = deprecated or set()
self.internal = internal or set()

if self.deprecated - self.options:
raise ValueError("The deprecated options must be a subset of the options.")

if self.internal - self.options:
raise ValueError("The internal options must be a subset of the options.")

if self.deprecated & self.internal:
raise ValueError(
"The deprecated and internal parameters should not overlap."
)

def is_satisfied_by(self, val):
return isinstance(val, str) and val in self.options

Expand All @@ -251,13 +251,8 @@ def _mark_if_deprecated(self, option):
return option_str

def __str__(self):
visible_options = [o for o in self.options if o not in self.internal]

if not visible_options:
return ""

options_str = (
f"{', '.join([self._mark_if_deprecated(o) for o in visible_options])}"
f"{', '.join([self._mark_if_deprecated(o) for o in self.options])}"
)
return f"a str among {{{options_str}}}"

Expand Down Expand Up @@ -304,6 +299,7 @@ class Interval(_Constraint):
}
)
def __init__(self, type, left, right, *, closed):
super().__init__()
self.type = type
self.left = left
self.right = right
Expand Down Expand Up @@ -405,6 +401,7 @@ class _RandomStates(_Constraint):
"""

def __init__(self):
super().__init__()
self._constraints = [
Interval(Integral, 0, 2**32 - 1, closed="both"),
_InstancesOf(np.random.RandomState),
Expand All @@ -421,6 +418,19 @@ def __str__(self):
)


class Hidden:
    """Wrapper flagging a constraint that must stay out of error messages.

    The wrapper performs no validation of its own: it only keeps a reference
    to the wrapped constraint so that it can be unwrapped later on.

    Parameters
    ----------
    constraint : object
        The constraint to be used internally. It is stored unchanged.

    Attributes
    ----------
    constraint : object
        The wrapped constraint, exactly as passed to the constructor.
    """

    def __init__(self, constraint):
        # Keep the wrapped constraint untouched; nothing is converted here.
        self.constraint = constraint


def generate_invalid_param_val(constraint, constraints=None):
"""Return a value that does not satisfy the constraint.

Expand Down
64 changes: 30 additions & 34 deletions sklearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sklearn.base import BaseEstimator
from sklearn.utils import deprecated
from sklearn.utils._param_validation import Hidden
from sklearn.utils._param_validation import Interval
from sklearn.utils._param_validation import StrOptions
from sklearn.utils._param_validation import _ArrayLikes
Expand Down Expand Up @@ -129,13 +130,12 @@ def test_interval_errors(params, error, match):

def test_stroptions():
"""Sanity check for the StrOptions constraint"""
options = StrOptions({"a", "b", "c"}, deprecated={"c"}, internal={"b"})
options = StrOptions({"a", "b", "c"}, deprecated={"c"})
assert options.is_satisfied_by("a")
assert options.is_satisfied_by("c")
assert not options.is_satisfied_by("d")

assert "'c' (deprecated)" in str(options)
assert "'b'" not in str(options)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -420,51 +420,47 @@ def test_validate_params_estimator():
est.fit()


def test_internal_values_not_exposed():
"""Check that valid values that are for internal purpose, e.g. "warn" or
"deprecated" are not exposed in the error message
"""
def test_stroptions_deprecated_subset():
"""Check that the deprecated parameter must be a subset of options."""
with pytest.raises(ValueError, match="deprecated options must be a subset"):
StrOptions({"a", "b", "c"}, deprecated={"a", "d"})

@validate_params({"param": [StrOptions({"auto", "warn"}, internal={"warn"})]})

def test_hidden_constraint():
"""Check that internal constraints are not exposed in the error message."""

@validate_params({"param": [Hidden(list), dict]})
def f(param):
pass

# list and dict are valid params
f({"a": 1, "b": 2, "c": 3})
f([1, 2, 3])

with pytest.raises(ValueError, match="The 'param' parameter") as exc_info:
f(param="bad")

# the list option is not exposed in the error message
err_msg = str(exc_info.value)
assert "a str among" in err_msg
assert "auto" in err_msg
assert "warn" not in err_msg
assert "an instance of 'dict'" in err_msg
assert "an instance of 'list'" not in err_msg


# no error
f(param="warn")
def test_hidden_stroptions():
"""Check that we can have 2 StrOptions constraints, one being hidden."""

@validate_params({"param": [int, StrOptions({"warn"}, internal={"warn"})]})
def g(param):
@validate_params({"param": [StrOptions({"auto"}), Hidden(StrOptions({"warn"}))]})
def f(param):
pass

# "auto" and "warn" are valid params
f("auto")
f("warn")

with pytest.raises(ValueError, match="The 'param' parameter") as exc_info:
g(param="bad")
f(param="bad")

# the "warn" option is not exposed in the error message
err_msg = str(exc_info.value)
assert "a str among" not in err_msg
assert "auto" in err_msg
assert "warn" not in err_msg

# no error
g(param="warn")


def test_stroptions_deprecated_internal_overlap():
"""Check that the internal and deprecated parameters are not allowed to overlap."""
with pytest.raises(ValueError, match="should not overlap"):
StrOptions({"a", "b", "c"}, deprecated={"b", "c"}, internal={"a", "b"})


def test_stroptions_deprecated_internal_subset():
"""Check that the deprecated and internal parameters must be subsets of options."""
with pytest.raises(ValueError, match="deprecated options must be a subset"):
StrOptions({"a", "b", "c"}, deprecated={"a", "d"})

with pytest.raises(ValueError, match="internal options must be a subset"):
StrOptions({"a", "b", "c"}, internal={"a", "d"})