diff --git a/sklearn/cluster/_bicluster.py b/sklearn/cluster/_bicluster.py index 6b1c824fc32ec..83a44a371b9ef 100644 --- a/sklearn/cluster/_bicluster.py +++ b/sklearn/cluster/_bicluster.py @@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod import numpy as np +import numbers from scipy.linalg import norm from scipy.sparse import dia_matrix, issparse @@ -13,6 +14,7 @@ from . import KMeans, MiniBatchKMeans from ..base import BaseEstimator, BiclusterMixin from ..utils import check_random_state +from ..utils import check_scalar from ..utils.extmath import make_nonnegative, randomized_svd, safe_sparse_dot @@ -102,7 +104,7 @@ def __init__( self.n_init = n_init self.random_state = random_state - def _check_parameters(self): + def _check_parameters(self, n_samples): legal_svd_methods = ("randomized", "arpack") if self.svd_method not in legal_svd_methods: raise ValueError( @@ -110,6 +112,7 @@ def _check_parameters(self): self.svd_method, legal_svd_methods ) ) + check_scalar(self.n_init, "n_init", target_type=numbers.Integral, min_val=1) def fit(self, X, y=None): """Create a biclustering for X. @@ -128,7 +131,7 @@ def fit(self, X, y=None): SpectralBiclustering instance. 
""" X = self._validate_data(X, accept_sparse="csr", dtype=np.float64) - self._check_parameters() + self._check_parameters(X.shape[0]) self._fit(X) return self @@ -328,6 +331,16 @@ def __init__( n_clusters, svd_method, n_svd_vecs, mini_batch, init, n_init, random_state ) + def _check_parameters(self, n_samples): + super()._check_parameters(n_samples) + check_scalar( + self.n_clusters, + "n_clusters", + target_type=numbers.Integral, + min_val=1, + max_val=n_samples, + ) + def _fit(self, X): normalized_data, row_diag, col_diag = _scale_normalize(X) n_sv = 1 + int(np.ceil(np.log2(self.n_clusters))) @@ -492,8 +505,8 @@ def __init__( self.n_components = n_components self.n_best = n_best - def _check_parameters(self): - super()._check_parameters() + def _check_parameters(self, n_samples): + super()._check_parameters(n_samples) legal_methods = ("bistochastic", "scale", "log") if self.method not in legal_methods: raise ValueError( @@ -502,36 +515,49 @@ def _check_parameters(self): ) ) try: - int(self.n_clusters) - except TypeError: + check_scalar( + self.n_clusters, + "n_clusters", + target_type=numbers.Integral, + min_val=1, + max_val=n_samples, + ) + except (ValueError, TypeError): try: - r, c = self.n_clusters - int(r) - int(c) + n_row_clusters, n_column_clusters = self.n_clusters + check_scalar( + n_row_clusters, + "n_row_clusters", + target_type=numbers.Integral, + min_val=1, + max_val=n_samples, + ) + check_scalar( + n_column_clusters, + "n_column_clusters", + target_type=numbers.Integral, + min_val=1, + max_val=n_samples, + ) except (ValueError, TypeError) as e: raise ValueError( "Incorrect parameter n_clusters has value:" - " {}. It should either be a single integer" + f" {self.n_clusters}. 
It should either be a single integer" " or an iterable with two integers:" " (n_row_clusters, n_column_clusters)" + ". The values should be in the" + " range: [1, n_samples]" ) from e - if self.n_components < 1: - raise ValueError( - "Parameter n_components must be greater than 0," - " but its value is {}".format(self.n_components) - ) - if self.n_best < 1: - raise ValueError( - "Parameter n_best must be greater than 0, but its value is {}".format( - self.n_best - ) - ) - if self.n_best > self.n_components: - raise ValueError( - "n_best cannot be larger than n_components, but {} > {}".format( - self.n_best, self.n_components - ) - ) + check_scalar( + self.n_components, "n_components", target_type=numbers.Integral, min_val=1 + ) + check_scalar( + self.n_best, + "n_best", + target_type=numbers.Integral, + min_val=1, + max_val=self.n_components, + ) def _fit(self, X): n_sv = self.n_components diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py index ba6d91a537143..184fe3891804e 100644 --- a/sklearn/cluster/tests/test_bicluster.py +++ b/sklearn/cluster/tests/test_bicluster.py @@ -208,30 +208,65 @@ def test_perfect_checkerboard(): @pytest.mark.parametrize( - "args", + "params, type_err, err_msg", [ - {"n_clusters": (3, 3, 3)}, - {"n_clusters": "abc"}, - {"n_clusters": (3, "abc")}, - {"method": "unknown"}, - {"n_components": 0}, - {"n_best": 0}, - {"svd_method": "unknown"}, - {"n_components": 3, "n_best": 4}, + ({"n_init": 0}, ValueError, "n_init == 0, must be >= 1."), + ({"n_init": 1.5}, TypeError, "n_init must be an instance of"), + ( + {"n_clusters": "abc"}, + TypeError, + "n_clusters must be an instance of", + ), + ({"svd_method": "unknown"}, ValueError, "Unknown SVD method: 'unknown'"), ], ) -def test_errors(args): +def test_spectralcoclustering_parameter_validation(params, type_err, err_msg): + """Check parameter validation in `SpectralCoclustering`.""" data = np.arange(25).reshape((5, 5)) - - model = 
SpectralBiclustering(**args) - with pytest.raises(ValueError): + model = SpectralCoclustering(**params) + with pytest.raises(type_err, match=err_msg): model.fit(data) -def test_wrong_shape(): - model = SpectralBiclustering() - data = np.arange(27).reshape((3, 3, 3)) - with pytest.raises(ValueError): +@pytest.mark.parametrize( + "params, type_err, err_msg", + [ + ({"n_init": 0}, ValueError, "n_init == 0, must be >= 1."), + ({"n_init": 1.5}, TypeError, "n_init must be an instance of"), + ( + {"n_clusters": (3, 3, 3)}, + ValueError, + r"Incorrect parameter n_clusters has value: \(3, 3, 3\)", + ), + ( + {"n_clusters": "abc"}, + ValueError, + "Incorrect parameter n_clusters has value: abc", + ), + ( + {"n_clusters": (3, "abc")}, + ValueError, + r"Incorrect parameter n_clusters has value: \(3, 'abc'\)", + ), + ( + {"n_clusters": ("abc", 3)}, + ValueError, + r"Incorrect parameter n_clusters has value: \('abc', 3\)", + ), + ({"method": "unknown"}, ValueError, "Unknown method: 'unknown'"), + ({"n_components": 0}, ValueError, "n_components == 0, must be >= 1."), + ({"n_components": 1.5}, TypeError, "n_components must be an instance of"), + ({"n_components": 3, "n_best": 4}, ValueError, "n_best == 4, must be <= 3."), + ({"n_best": 0}, ValueError, "n_best == 0, must be >= 1."), + ({"n_best": 1.5}, TypeError, "n_best must be an instance of"), + ({"svd_method": "unknown"}, ValueError, "Unknown SVD method: 'unknown'"), + ], +) +def test_spectralbiclustering_parameter_validation(params, type_err, err_msg): + """Check parameter validation in `SpectralBiclustering`.""" + data = np.arange(25).reshape((5, 5)) + model = SpectralBiclustering(**params) + with pytest.raises(type_err, match=err_msg): model.fit(data)