15
15
from ..base import BaseEstimator
16
16
from ..base import DensityMixin
17
17
from ..exceptions import ConvergenceWarning
18
- from ..utils import check_array , check_random_state
18
+ from ..utils import check_random_state
19
19
from ..utils .validation import check_is_fitted
20
20
21
21
@@ -36,32 +36,6 @@ def _check_shape(param, param_shape, name):
36
36
"but got %s" % (name , param_shape , param .shape ))
37
37
38
38
39
def _check_X(X, n_components=None, n_features=None, ensure_min_samples=1):
    """Validate the input data X for a mixture model.

    Runs ``check_array`` with float dtypes, then optionally enforces a
    minimum sample count and an exact feature count.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    n_components : int, default=None
        When given, X must contain at least this many samples.

    n_features : int, default=None
        When given, X must contain exactly this many features.

    ensure_min_samples : int, default=1
        Forwarded to ``check_array``.

    Returns
    -------
    X : array, shape (n_samples, n_features)
        The validated array.
    """
    X = check_array(X, dtype=[np.float64, np.float32],
                    ensure_min_samples=ensure_min_samples)
    n_samples, n_feats = X.shape

    # Sample-count check first: its error takes precedence over the
    # feature-count check, matching the original statement order.
    if n_components is not None and n_samples < n_components:
        raise ValueError('Expected n_samples >= n_components '
                         'but got n_components = %d, n_samples = %d'
                         % (n_components, n_samples))

    if n_features is not None and n_feats != n_features:
        raise ValueError("Expected the input data X have %d features, "
                         "but got %d features"
                         % (n_features, n_feats))

    return X
63
-
64
-
65
39
class BaseMixture (DensityMixin , BaseEstimator , metaclass = ABCMeta ):
66
40
"""Base class for mixture models.
67
41
@@ -217,8 +191,12 @@ def fit_predict(self, X, y=None):
217
191
labels : array, shape (n_samples,)
218
192
Component labels.
219
193
"""
220
- X = _check_X (X , self .n_components , ensure_min_samples = 2 )
221
- self ._check_n_features (X , reset = True )
194
+ X = self ._validate_data (X , dtype = [np .float64 , np .float32 ],
195
+ ensure_min_samples = 2 )
196
+ if X .shape [0 ] < self .n_components :
197
+ raise ValueError ("Expected n_samples >= n_components "
198
+ f"but got n_components = { self .n_components } , "
199
+ f"n_samples = { X .shape [0 ]} " )
222
200
self ._check_initial_parameters (X )
223
201
224
202
# if we enable warm_start, we will have a unique initialisation
@@ -335,7 +313,7 @@ def score_samples(self, X):
335
313
Log probabilities of each data point in X.
336
314
"""
337
315
check_is_fitted (self )
338
- X = _check_X (X , None , self . means_ . shape [ 1 ] )
316
+ X = self . _validate_data (X , reset = False )
339
317
340
318
return logsumexp (self ._estimate_weighted_log_prob (X ), axis = 1 )
341
319
@@ -370,7 +348,7 @@ def predict(self, X):
370
348
Component labels.
371
349
"""
372
350
check_is_fitted (self )
373
- X = _check_X (X , None , self . means_ . shape [ 1 ] )
351
+ X = self . _validate_data (X , reset = False )
374
352
return self ._estimate_weighted_log_prob (X ).argmax (axis = 1 )
375
353
376
354
def predict_proba (self , X ):
@@ -389,7 +367,7 @@ def predict_proba(self, X):
389
367
the model given each sample.
390
368
"""
391
369
check_is_fitted (self )
392
- X = _check_X (X , None , self . means_ . shape [ 1 ] )
370
+ X = self . _validate_data (X , reset = False )
393
371
_ , log_resp = self ._estimate_log_prob_resp (X )
394
372
return np .exp (log_resp )
395
373
0 commit comments