Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MNT use check_scalar to validate scalar in AffinityPropagation #20723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions sklearn/cluster/_affinity_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@

# License: BSD 3 clause

import numpy as np
import numbers
import warnings

import numpy as np

from ..exceptions import ConvergenceWarning
from ..base import BaseEstimator, ClusterMixin
from ..utils import as_float_array, check_random_state
from ..utils import check_scalar
from ..utils.deprecation import deprecated
from ..utils.validation import check_is_fitted
from ..metrics import euclidean_distances
Expand Down Expand Up @@ -132,8 +135,6 @@ def affinity_propagation(

if preference is None:
preference = np.median(S)
if damping < 0.5 or damping >= 1:
raise ValueError("damping must be >= 0.5 and < 1")

preference = np.array(preference)

Expand Down Expand Up @@ -456,6 +457,22 @@ def fit(self, X, y=None):
% str(self.affinity)
)

check_scalar(
self.damping,
"damping",
target_type=numbers.Real,
min_val=0.5,
max_val=1,
closed="right",
)
check_scalar(self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1)
check_scalar(
self.convergence_iter,
"convergence_iter",
target_type=numbers.Integral,
min_val=1,
)

(
self.cluster_centers_indices_,
self.labels_,
Expand Down
38 changes: 28 additions & 10 deletions sklearn/cluster/tests/test_affinity_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,35 @@ def test_affinity_propagation():
)
assert_array_equal(labels, labels_no_copy)

# Test input validation
with pytest.raises(ValueError):

def test_affinity_propagation_affinity_shape():
"""Check the shape of the affinity matrix when using `affinity_propagation`."""
S = -euclidean_distances(X, squared=True)
err_msg = "S must be a square array"
with pytest.raises(ValueError, match=err_msg):
affinity_propagation(S[:, :-1])
with pytest.raises(ValueError):
affinity_propagation(S, damping=0)
af = AffinityPropagation(affinity="unknown", random_state=78)
with pytest.raises(ValueError):
af.fit(X)
af_2 = AffinityPropagation(affinity="precomputed", random_state=21)
with pytest.raises(TypeError):
af_2.fit(csr_matrix((3, 3)))


@pytest.mark.parametrize(
    "data, params, err_type, err_msg",
    [
        (X, {"damping": 0}, ValueError, "damping == 0, must be >= 0.5"),
        (X, {"damping": 2}, ValueError, "damping == 2, must be < 1"),
        (X, {"max_iter": 0}, ValueError, "max_iter == 0, must be >= 1."),
        (X, {"convergence_iter": 0}, ValueError, "convergence_iter == 0, must be >= 1"),
        (X, {"affinity": "unknown"}, ValueError, "Affinity must be"),
        (
            csr_matrix((3, 3)),
            {"affinity": "precomputed"},
            TypeError,
            "A sparse matrix was passed, but dense data is required",
        ),
    ],
)
def test_affinity_propagation_params_validation(data, params, err_type, err_msg):
    """Check the parameters validation in `AffinityPropagation`.

    Each case builds an estimator from ``params`` and fits it on ``data``;
    the call to ``fit`` must raise ``err_type`` with a message matching
    ``err_msg``.

    Parameters
    ----------
    data : array-like or sparse matrix
        Input passed to ``fit``. Named ``data`` rather than ``input`` so the
        builtin ``input`` is not shadowed.
    params : dict
        Keyword arguments forwarded to the ``AffinityPropagation``
        constructor.
    err_type : type
        Exception class expected from ``fit``.
    err_msg : str
        Pattern handed to ``pytest.raises(match=...)``.
    """
    # Validation is expected to happen in `fit`, not in the constructor, so
    # both the construction and the fit live inside the `raises` context.
    with pytest.raises(err_type, match=err_msg):
        AffinityPropagation(**params).fit(data)


def test_affinity_propagation_predict():
Expand Down
2 changes: 1 addition & 1 deletion sklearn/neighbors/tests/test_nca.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_params_validation():
)
with pytest.raises(ValueError, match=re.escape(msg)):
NCA(init=1).fit(X, y)
with pytest.raises(ValueError, match="`max_iter`= -1, must be >= 1."):
with pytest.raises(ValueError, match="max_iter == -1, must be >= 1."):
NCA(max_iter=-1).fit(X, y)
init = rng.rand(5, 3)
msg = (
Expand Down
72 changes: 59 additions & 13 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for input validation functions"""

import numbers
import warnings
import os
import re
Expand Down Expand Up @@ -1004,43 +1005,88 @@ def __init__(self):
_num_samples(TestNoLenWeirdShape())


@pytest.mark.parametrize(
"x, target_type, min_val, max_val", [(3, int, 2, 5), (2.5, float, 2, 5)]
)
def test_check_scalar_valid(x, target_type, min_val, max_val):
@pytest.mark.parametrize("x", [2, 3, 2.5, 5])
def test_check_scalar_valid(x):
"""Test that check_scalar returns no error/warning if valid inputs are
provided"""
with pytest.warns(None) as record:
check_scalar(
x, "test_name", target_type=target_type, min_val=min_val, max_val=max_val
scalar = check_scalar(
x,
"test_name",
target_type=numbers.Real,
min_val=2,
max_val=5,
closed="neither",
)
assert len(record) == 0
assert scalar == x


@pytest.mark.parametrize(
"x, target_name, target_type, min_val, max_val, err_msg",
"x, target_name, target_type, min_val, max_val, closed, err_msg",
[
(
1,
"test_name1",
float,
2,
4,
"neither",
TypeError(
"`test_name1` must be an instance of "
"<class 'float'>, not <class 'int'>."
"test_name1 must be an instance of <class 'float'>, not <class 'int'>."
),
),
(1, "test_name2", int, 2, 4, ValueError("`test_name2`= 1, must be >= 2.")),
(5, "test_name3", int, 2, 4, ValueError("`test_name3`= 5, must be <= 4.")),
(
1,
"test_name2",
int,
2,
4,
"neither",
ValueError("test_name2 == 1, must be >= 2."),
),
(
5,
"test_name3",
int,
2,
4,
"neither",
ValueError("test_name3 == 5, must be <= 4."),
),
(
2,
"test_name4",
int,
2,
4,
"left",
ValueError("test_name4 == 2, must be > 2."),
),
(
4,
"test_name5",
int,
2,
4,
"right",
ValueError("test_name5 == 4, must be < 4."),
),
],
)
def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val, err_msg):
def test_check_scalar_invalid(
x, target_name, target_type, min_val, max_val, closed, err_msg
):
"""Test that check_scalar returns the right error if a wrong input is
given"""
with pytest.raises(Exception) as raised_error:
check_scalar(
x, target_name, target_type=target_type, min_val=min_val, max_val=max_val
x,
target_name,
target_type=target_type,
min_val=min_val,
max_val=max_val,
closed=closed,
)
assert str(raised_error.value) == str(err_msg)
assert type(raised_error.value) == type(err_msg)
Expand Down
48 changes: 39 additions & 9 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import wraps
import warnings
import numbers
import operator

import numpy as np
import scipy.sparse as sp
Expand Down Expand Up @@ -1231,7 +1232,15 @@ def check_non_negative(X, whom):
raise ValueError("Negative values in data passed to %s" % whom)


def check_scalar(x, name, target_type, *, min_val=None, max_val=None):
def check_scalar(
x,
name,
target_type,
*,
min_val=None,
max_val=None,
closed="neither",
):
"""Validate scalar parameters type and value.

Parameters
Expand All @@ -1249,12 +1258,21 @@ def check_scalar(x, name, target_type, *, min_val=None, max_val=None):
The minimum valid value the parameter can take. If None (default) it
is implied that the parameter does not have a lower bound.

max_val : float or int, default=None
max_val : float or int, default=None
The maximum valid value the parameter can take. If None (default) it
is implied that the parameter does not have an upper bound.

Raises
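closed : {"left", "right", "both", "neither"}, default="neither"
Which bound(s) are excluded from the valid range. "left" rejects
`x == min_val`, "right" rejects `x == max_val`, "both" rejects either
bound, and "neither" (default) accepts both `min_val` and `max_val`.
Note that this is the opposite of the usual meaning of a "closed"
interval.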

Returns
-------
x : numbers.Number
The validated number.

Raises
------
TypeError
If the parameter's type does not match the desired type.

Expand All @@ -1263,15 +1281,27 @@ def check_scalar(x, name, target_type, *, min_val=None, max_val=None):
"""

if not isinstance(x, target_type):
raise TypeError(
"`{}` must be an instance of {}, not {}.".format(name, target_type, type(x))
raise TypeError(f"{name} must be an instance of {target_type}, not {type(x)}.")

expected_closed = {"left", "right", "both", "neither"}
if closed not in expected_closed:
raise ValueError(f"Unknown value for `closed`: {closed}")

comparison_operator = operator.le if closed in ("left", "both") else operator.lt
if min_val is not None and comparison_operator(x, min_val):
raise ValueError(
f"{name} == {x}, must be"
f" {'>' if closed in ('left', 'both') else '>='} {min_val}."
)

if min_val is not None and x < min_val:
raise ValueError("`{}`= {}, must be >= {}.".format(name, x, min_val))
comparison_operator = operator.ge if closed in ("right", "both") else operator.gt
if max_val is not None and comparison_operator(x, max_val):
raise ValueError(
f"{name} == {x}, must be"
f" {'<' if closed in ('right', 'both') else '<='} {max_val}."
)

if max_val is not None and x > max_val:
raise ValueError("`{}`= {}, must be <= {}.".format(name, x, max_val))
return x


def _check_psd_eigenvalues(lambdas, enable_warnings=False):
Expand Down