Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3bc8862

Browse files
thomasjpfan authored and
MarcBresson committed
FIX Raises error in PLSRegression for invalid n_components (scikit-learn#29710)
1 parent 2333813 commit 3bc8862

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

doc/whats_new/v1.6.rst

+6
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ Changelog
172172
now accepts string format or callable to generate feature names. :pr:`28934` by
173173
:user:`Marc Bresson <MarcBresson>`.
174174

175+
:mod:`sklearn.cross_decomposition`
176+
..................................
177+
178+
- |Fix| :class:`cross_decomposition.PLSRegression` properly raises an error when
179+
`n_components` is larger than `n_samples`. :pr:`29710` by `Thomas Fan`_.
180+
175181
:mod:`sklearn.datasets`
176182
.......................
177183

sklearn/cross_decomposition/_pls.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ def fit(self, X, y=None, Y=None):
291291
# With PLSRegression n_components is bounded by the rank of (X.T X) see
292292
# Wegelin page 25. With CCA and PLSCanonical, n_components is bounded
293293
# by the rank of X and the rank of Y: see Wegelin page 12
294-
rank_upper_bound = p if self.deflation_mode == "regression" else min(n, p, q)
294+
rank_upper_bound = (
295+
min(n, p) if self.deflation_mode == "regression" else min(n, p, q)
296+
)
295297
if n_components > rank_upper_bound:
296298
raise ValueError(
297299
f"`n_components` upper bound is {rank_upper_bound}. "

sklearn/cross_decomposition/tests/test_pls.py

+11
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,17 @@ def test_n_components_upper_bounds(Estimator):
480480
est.fit(X, Y)
481481

482482

483+
def test_n_components_upper_PLSRegression():
484+
"""Check the validation of `n_components` upper bounds for PLSRegression."""
485+
rng = np.random.RandomState(0)
486+
X = rng.randn(20, 64)
487+
Y = rng.randn(20, 3)
488+
est = PLSRegression(n_components=30)
489+
err_msg = "`n_components` upper bound is 20. Got 30 instead. Reduce `n_components`."
490+
with pytest.raises(ValueError, match=err_msg):
491+
est.fit(X, Y)
492+
493+
483494
@pytest.mark.parametrize("n_samples, n_features", [(100, 10), (100, 200)])
484495
def test_singular_value_helpers(n_samples, n_features, global_random_seed):
485496
# Make sure SVD and power method give approximately the same results

0 commit comments

Comments
 (0)