Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit cf296c7

Browse files
authored
ENH Checks n_features_in_ after fitting in mixture (#19540)
1 parent 3e64e9e commit cf296c7

File tree

7 files changed

+36
-58
lines changed

7 files changed

+36
-58
lines changed

sklearn/mixture/_base.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..base import BaseEstimator
1616
from ..base import DensityMixin
1717
from ..exceptions import ConvergenceWarning
18-
from ..utils import check_array, check_random_state
18+
from ..utils import check_random_state
1919
from ..utils.validation import check_is_fitted
2020

2121

@@ -36,32 +36,6 @@ def _check_shape(param, param_shape, name):
3636
"but got %s" % (name, param_shape, param.shape))
3737

3838

39-
def _check_X(X, n_components=None, n_features=None, ensure_min_samples=1):
40-
"""Check the input data X.
41-
42-
Parameters
43-
----------
44-
X : array-like of shape (n_samples, n_features)
45-
46-
n_components : int
47-
48-
Returns
49-
-------
50-
X : array, shape (n_samples, n_features)
51-
"""
52-
X = check_array(X, dtype=[np.float64, np.float32],
53-
ensure_min_samples=ensure_min_samples)
54-
if n_components is not None and X.shape[0] < n_components:
55-
raise ValueError('Expected n_samples >= n_components '
56-
'but got n_components = %d, n_samples = %d'
57-
% (n_components, X.shape[0]))
58-
if n_features is not None and X.shape[1] != n_features:
59-
raise ValueError("Expected the input data X have %d features, "
60-
"but got %d features"
61-
% (n_features, X.shape[1]))
62-
return X
63-
64-
6539
class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
6640
"""Base class for mixture models.
6741
@@ -217,8 +191,12 @@ def fit_predict(self, X, y=None):
217191
labels : array, shape (n_samples,)
218192
Component labels.
219193
"""
220-
X = _check_X(X, self.n_components, ensure_min_samples=2)
221-
self._check_n_features(X, reset=True)
194+
X = self._validate_data(X, dtype=[np.float64, np.float32],
195+
ensure_min_samples=2)
196+
if X.shape[0] < self.n_components:
197+
raise ValueError("Expected n_samples >= n_components "
198+
f"but got n_components = {self.n_components}, "
199+
f"n_samples = {X.shape[0]}")
222200
self._check_initial_parameters(X)
223201

224202
# if we enable warm_start, we will have a unique initialisation
@@ -335,7 +313,7 @@ def score_samples(self, X):
335313
Log probabilities of each data point in X.
336314
"""
337315
check_is_fitted(self)
338-
X = _check_X(X, None, self.means_.shape[1])
316+
X = self._validate_data(X, reset=False)
339317

340318
return logsumexp(self._estimate_weighted_log_prob(X), axis=1)
341319

@@ -370,7 +348,7 @@ def predict(self, X):
370348
Component labels.
371349
"""
372350
check_is_fitted(self)
373-
X = _check_X(X, None, self.means_.shape[1])
351+
X = self._validate_data(X, reset=False)
374352
return self._estimate_weighted_log_prob(X).argmax(axis=1)
375353

376354
def predict_proba(self, X):
@@ -389,7 +367,7 @@ def predict_proba(self, X):
389367
the model given each sample.
390368
"""
391369
check_is_fitted(self)
392-
X = _check_X(X, None, self.means_.shape[1])
370+
X = self._validate_data(X, reset=False)
393371
_, log_resp = self._estimate_log_prob_resp(X)
394372
return np.exp(log_resp)
395373

sklearn/mixture/_bayesian_mixture.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ class BayesianGaussianMixture(BaseMixture):
288288
(n_features) if 'diag',
289289
float if 'spherical'
290290
291+
n_features_in_ : int
292+
Number of features seen during :term:`fit`.
293+
294+
.. versionadded:: 0.24
295+
291296
Examples
292297
--------
293298
>>> import numpy as np

sklearn/mixture/_gaussian_mixture.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,11 @@ class GaussianMixture(BaseMixture):
582582
Lower bound value on the log-likelihood (of the training data with
583583
respect to the model) of the best fit of EM.
584584
585+
n_features_in_ : int
586+
Number of features seen during :term:`fit`.
587+
588+
.. versionadded:: 0.24
589+
585590
Examples
586591
--------
587592
>>> import numpy as np

sklearn/mixture/tests/test_gaussian_mixture.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -172,30 +172,6 @@ def test_gaussian_mixture_attributes():
172172
assert gmm.init_params == init_params
173173

174174

175-
def test_check_X():
176-
from sklearn.mixture._base import _check_X
177-
rng = np.random.RandomState(0)
178-
179-
n_samples, n_components, n_features = 10, 2, 2
180-
181-
X_bad_dim = rng.rand(n_components - 1, n_features)
182-
assert_raise_message(ValueError,
183-
'Expected n_samples >= n_components '
184-
'but got n_components = %d, n_samples = %d'
185-
% (n_components, X_bad_dim.shape[0]),
186-
_check_X, X_bad_dim, n_components)
187-
188-
X_bad_dim = rng.rand(n_components, n_features + 1)
189-
assert_raise_message(ValueError,
190-
'Expected the input data X have %d features, '
191-
'but got %d features'
192-
% (n_features, X_bad_dim.shape[1]),
193-
_check_X, X_bad_dim, n_components, n_features)
194-
195-
X = rng.rand(n_samples, n_features)
196-
assert_array_equal(X, _check_X(X, n_components, n_features))
197-
198-
199175
def test_check_weights():
200176
rng = np.random.RandomState(0)
201177
rand_data = RandomData(rng)

sklearn/mixture/tests/test_mixture.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,19 @@ def test_gaussian_mixture_n_iter(estimator):
2121
estimator.set_params(max_iter=max_iter)
2222
estimator.fit(X)
2323
assert estimator.n_iter_ == max_iter
24+
25+
26+
@pytest.mark.parametrize(
27+
"estimator",
28+
[GaussianMixture(),
29+
BayesianGaussianMixture()]
30+
)
31+
def test_mixture_n_components_greater_than_n_samples_error(estimator):
32+
"""Check error when n_components > n_samples"""
33+
rng = np.random.RandomState(0)
34+
X = rng.rand(10, 5)
35+
estimator.set_params(n_components=12)
36+
37+
msg = "Expected n_samples >= n_components"
38+
with pytest.raises(ValueError, match=msg):
39+
estimator.fit(X)

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ def test_search_cv(estimator, check, request):
264264
'calibration',
265265
'compose',
266266
'feature_extraction',
267-
'mixture',
268267
'model_selection',
269268
'multiclass',
270269
'multioutput',

sklearn/tests/test_docstring_parameters.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def _construct_searchcv_instance(SearchCV):
193193
'kernel_ridge',
194194
'linear_model',
195195
'manifold',
196-
'mixture',
197196
'model_selection',
198197
'multiclass',
199198
'multioutput',

0 commit comments

Comments
 (0)