MNT use check_scalar in SpectralBiClustering and SpectralCoClustering #20817
sklearn/cluster/_bicluster.py
Outdated
"n_clusters": { "target_type": numbers.Integral,"min_val": 1, "max_val": n_samples},
"n_init": { "target_type": numbers.Integral,"min_val": 1 },
}
for scalar_name in scalars_checks:
You can make 2 calls and not use a loop.
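A minimal sketch of what "2 calls and not a loop" could look like. To keep the snippet runnable without scikit-learn, it uses `check_scalar_sketch`, a hypothetical and simplified stand-in for `sklearn.utils.check_scalar`; `check_parameters` is also an illustrative name, not code from this PR.

```python
import numbers


def check_scalar_sketch(x, name, target_type, min_val=None, max_val=None):
    """Hypothetical, simplified stand-in for sklearn.utils.check_scalar."""
    if not isinstance(x, target_type):
        raise TypeError(
            f"{name} must be an instance of {target_type}, not {type(x)}."
        )
    if min_val is not None and x < min_val:
        raise ValueError(f"{name} == {x}, must be >= {min_val}.")
    if max_val is not None and x > max_val:
        raise ValueError(f"{name} == {x}, must be <= {max_val}.")
    return x


def check_parameters(n_clusters, n_init, n_samples):
    # Two explicit calls: each parameter's constraints are visible at the
    # call site, with no dictionary of keyword arguments and no loop.
    check_scalar_sketch(
        n_clusters, "n_clusters", numbers.Integral, min_val=1, max_val=n_samples
    )
    check_scalar_sketch(n_init, "n_init", numbers.Integral, min_val=1)
```

With only two parameters to validate, the explicit calls are shorter than building a dictionary and iterating over it, and a failing check points directly at the offending line.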
sklearn/cluster/_bicluster.py
Outdated
legal_svd_methods = ("randomized", "arpack")
if self.svd_method not in legal_svd_methods:
raise ValueError(
"Unknown SVD method: '{0}'. svd_method must be one of {1}.".format(
self.svd_method, legal_svd_methods
)
)
scalars_checks = {
"n_clusters": { "target_type": numbers.Integral,"min_val": 1, "max_val": n_samples},
"n_init": { "target_type": numbers.Integral,"min_val": 1 },
There are some PEP8 issues here.
sklearn/cluster/_bicluster.py
Outdated
"n_components": { "target_type": numbers.Integral,"min_val": 1},
"n_best": { "target_type": numbers.Integral,"min_val": 1, "max_val": self.n_components },
}
for scalar_name in scalars_checks:
same here
Thanks for the reviews. I made the changes. If I'm still missing something, please let me know.
The CIs are failing. Can you look at the reported error? You should be able to reproduce it locally by running the tests associated with the scikit-learn estimator that you modified.
I am posting the patch that should most probably fix the CI, together with some additional changes that I would have asked for later in a thorough review:
diff --git a/sklearn/cluster/_bicluster.py b/sklearn/cluster/_bicluster.py
index 849afd6cf5..5bb8b95cc6 100644
--- a/sklearn/cluster/_bicluster.py
+++ b/sklearn/cluster/_bicluster.py
@@ -106,7 +106,7 @@ class BaseSpectral(BiclusterMixin, BaseEstimator, metaclass=ABCMeta):
self.n_init = n_init
self.random_state = random_state
- def _check_parameters(self, n_samples):
+ def _check_parameters(self, n_samples, n_features):
legal_svd_methods = ("randomized", "arpack")
if self.svd_method not in legal_svd_methods:
raise ValueError(
@@ -114,8 +114,13 @@ class BaseSpectral(BiclusterMixin, BaseEstimator, metaclass=ABCMeta):
self.svd_method, legal_svd_methods
)
)
- check_scalar(self.n_clusters, "n_clusters", target_type=numbers.Integral, min_val=1, max_val=n_samples)
- check_scalar(self.n_init, "n_init", target_type=numbers.Integral, min_val=1)
+ check_scalar(
+ self.n_init,
+ "n_init",
+ target_type=numbers.Integral,
+ min_val=1,
+ include_boundaries="left",
+ )
def fit(self, X, y=None):
"""Creates a biclustering for X.
@@ -128,7 +133,7 @@ class BaseSpectral(BiclusterMixin, BaseEstimator, metaclass=ABCMeta):
"""
X = self._validate_data(X, accept_sparse="csr", dtype=np.float64)
- self._check_parameters(X.shape[0])
+ self._check_parameters(*X.shape)
self._fit(X)
return self
@@ -326,6 +331,17 @@ class SpectralCoclustering(BaseSpectral):
n_clusters, svd_method, n_svd_vecs, mini_batch, init, n_init, random_state
)
+ def _check_parameters(self, n_samples, n_features):
+ super()._check_parameters(n_samples, n_features)
+ check_scalar(
+ self.n_clusters,
+ "n_clusters",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=n_samples,
+ include_boundaries="both",
+ )
+
def _fit(self, X):
normalized_data, row_diag, col_diag = _scale_normalize(X)
n_sv = 1 + int(np.ceil(np.log2(self.n_clusters)))
@@ -487,8 +503,8 @@ class SpectralBiclustering(BaseSpectral):
self.n_components = n_components
self.n_best = n_best
- def _check_parameters(self, n_sample):
- super()._check_parameters()
+ def _check_parameters(self, n_samples, n_features):
+ super()._check_parameters(n_samples, n_features)
legal_methods = ("bistochastic", "scale", "log")
if self.method not in legal_methods:
raise ValueError(
@@ -496,22 +512,60 @@ class SpectralBiclustering(BaseSpectral):
self.method, legal_methods
)
)
- try:
- int(self.n_clusters)
- except TypeError:
+
+ n_clusters_type_error = (
+ f"Incorrect parameter n_clusters has value: {self.n_clusters}. It "
+ "should either be a single integer or an iterable with two "
+ "integers: (n_row_clusters, n_column_clusters)"
+ )
+ if isinstance(self.n_clusters, numbers.Integral):
+ check_scalar(
+ self.n_clusters,
+ "n_clusters",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=n_samples,
+ include_boundaries="both",
+ )
+ elif isinstance(self.n_clusters, tuple):
try:
- r, c = self.n_clusters
- int(r)
- int(c)
- except (ValueError, TypeError) as e:
- raise ValueError(
- "Incorrect parameter n_clusters has value:"
- " {}. It should either be a single integer"
- " or an iterable with two integers:"
- " (n_row_clusters, n_column_clusters)"
- ) from e
- check_scalar(self.n_components, "n_components", target_type=numbers.Integral, min_val=1)
- check_scalar(self.n_best, "n_best", target_type=numbers.Integral, min_val=1, max_val=self.n_components)
+ rows, columns = self.n_clusters
+ except ValueError as e:
+ raise ValueError(n_clusters_type_error) from e
+ check_scalar(
+ rows,
+ "n_rows from n_clusters",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=n_samples,
+ include_boundaries="both",
+ )
+ check_scalar(
+ columns,
+ "n_columns from n_clusters",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=n_features,
+ include_boundaries="both",
+ )
+ else:
+ raise TypeError(n_clusters_type_error)
+
+ check_scalar(
+ self.n_components,
+ "n_components",
+ target_type=numbers.Integral,
+ min_val=1,
+ include_boundaries="left",
+ )
+ check_scalar(
+ self.n_best,
+ "n_best",
+ target_type=numbers.Integral,
+ min_val=1,
+ max_val=self.n_components,
+ include_boundaries="both",
+ )
def _fit(self, X):
n_sv = self.n_components
diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py
index ba6d91a537..0eb6a6805e 100644
--- a/sklearn/cluster/tests/test_bicluster.py
+++ b/sklearn/cluster/tests/test_bicluster.py
@@ -208,23 +208,36 @@ def test_perfect_checkerboard():
@pytest.mark.parametrize(
- "args",
+ "args, err_type, err_msg",
[
- {"n_clusters": (3, 3, 3)},
- {"n_clusters": "abc"},
- {"n_clusters": (3, "abc")},
- {"method": "unknown"},
- {"n_components": 0},
- {"n_best": 0},
- {"svd_method": "unknown"},
- {"n_components": 3, "n_best": 4},
+ (
+ {"n_clusters": (3, 3, 3)},
+ ValueError,
+ r"Incorrect parameter n_clusters has value: \(3, 3, 3\)",
+ ),
+ (
+ {"n_clusters": "abc"},
+ TypeError,
+ "Incorrect parameter n_clusters has value: abc",
+ ),
+ (
+ {"n_clusters": (3, "abc")},
+ TypeError,
+ "n_columns from n_clusters must be an instance of <class"
+ " 'numbers.Integral'>, not <class 'str'>.",
+ ),
+ ({"method": "unknown"}, ValueError, "Unknown method: 'unknown'"),
+ ({"n_components": 0}, ValueError, "n_components == 0, must be >= 1."),
+ ({"n_best": 0}, ValueError, "n_best == 0, must be >= 1."),
+ ({"svd_method": "unknown"}, ValueError, "Unknown SVD method: 'unknown'"),
+ ({"n_components": 3, "n_best": 4}, ValueError, "n_best == 4, must be <= 3."),
],
)
-def test_errors(args):
+def test_errors(args, err_type, err_msg):
data = np.arange(25).reshape((5, 5))
model = SpectralBiclustering(**args)
- with pytest.raises(ValueError):
+ with pytest.raises(err_type, match=err_msg):
model.fit(data)
In short, this patch does:
@creatornadiran Could you apply these changes so that I can make another round of review?
Adding "Request changes" to show that we have already reviewed this PR.
Wouldn't it be a bit difficult to check n_jobs with check_scalar?
if self.n_jobs is not None:
    check_scalar(self.n_jobs, "n_jobs", numbers.Integral)
This should be enough.
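The guard above can be sketched as a standalone function. This is only an illustration: `validate_n_jobs` is a hypothetical name, and a plain `isinstance` check stands in for the type check that `check_scalar` performs, so the snippet runs without scikit-learn.

```python
import numbers


def validate_n_jobs(n_jobs):
    """Hypothetical helper: accept None or any integer, reject other types."""
    # None is a legal value for an optional parameter, so skip validation.
    if n_jobs is None:
        return n_jobs
    # Stand-in for check_scalar(n_jobs, "n_jobs", numbers.Integral):
    # no min_val or max_val is given, so negative integers are accepted too.
    if not isinstance(n_jobs, numbers.Integral):
        raise TypeError(
            f"n_jobs must be an instance of {numbers.Integral}, "
            f"not {type(n_jobs)}."
        )
    return n_jobs
```

Skipping the check for `None` is what makes an optional parameter compatible with a validator that otherwise insists on a single target type.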
Be aware that the CIs are failing. You should check the logs.
Be sure to apply
I got a "'SpectralCoclustering' object has no attribute 'n_jobs'" error.
Finally, all CIs have passed.
Thank you for the PR @creatornadiran !
Thanks for the reply! I missed that the But I can't get why should I add
Currently,
I understand now, thanks. I made the changes.
Small comment regarding n_jobs, otherwise LGTM!
Thanks for the check and the approval!
There is a linting error, which can be fixed by running black .
I ran the
LGTM
super()._check_parameters is already checking svd_method, so there is no need for this block of code
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
I'm still getting 2 errors and don't understand why.
The string that you passed was not matching the error message. You can check the fix here: cda955d
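The thread does not say what the mismatch was. One common cause, shown here only as an assumption, is that `pytest.raises(..., match=...)` treats its argument as a regular expression and applies `re.search` to the error message, so literal parentheses must be escaped. The sketch below reproduces that behavior with the standard `re` module; the message text is taken from the test cases in the patch above.

```python
import re

# Error message as it would be raised for n_clusters=(3, 3, 3).
message = (
    "Incorrect parameter n_clusters has value: (3, 3, 3). It "
    "should either be a single integer or an iterable with two "
    "integers: (n_row_clusters, n_column_clusters)"
)

# Unescaped: the parentheses become a regex group, so the pattern expects
# "value: 3, 3, 3" with no literal "(" and does not match the message.
unescaped = "Incorrect parameter n_clusters has value: (3, 3, 3)"
unescaped_matches = re.search(unescaped, message) is not None

# Escaped: re.escape (or a raw string with manual backslashes) makes the
# parentheses literal, so the pattern matches the message.
escaped = re.escape(unescaped)
escaped_matches = re.search(escaped, message) is not None
```

This is why the parametrized test writes the pattern as `r"Incorrect parameter n_clusters has value: \(3, 3, 3\)"`.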
CIs are green. Let's merge then. Thanks @creatornadiran
Wow, I didn't notice. Thanks for the fix.
Reference Issues/PRs
Reference Issue #20724
PR #20723
What does this implement/fix? Explain your changes.
Used the check_scalar function instead of if-else blocks to validate parameters.
Any other comments?
Please let me know if there is a mistake or you have a suggestion.