|
16 | 16 | from scipy.interpolate import interp1d
|
17 | 17 |
|
18 | 18 | from .base import _preprocess_data
|
19 |
| -from ..base import BaseEstimator, TransformerMixin |
| 19 | +from ..base import BaseEstimator |
20 | 20 | from ..externals import six
|
21 | 21 | from ..externals.joblib import Memory, Parallel, delayed
|
22 |
| -from ..utils import (as_float_array, check_random_state, check_X_y, |
23 |
| - check_array, safe_mask) |
| 22 | +from ..feature_selection.base import SelectorMixin |
| 23 | +from ..utils import (as_float_array, check_random_state, check_X_y, safe_mask) |
24 | 24 | from ..utils.validation import check_is_fitted
|
25 | 25 | from .least_angle import lars_path, LassoLarsIC
|
26 | 26 | from .logistic import LogisticRegression
|
@@ -59,7 +59,7 @@ def _resample_model(estimator_func, X, y, scaling=.5, n_resampling=200,
|
59 | 59 |
|
60 | 60 |
|
61 | 61 | class BaseRandomizedLinearModel(six.with_metaclass(ABCMeta, BaseEstimator,
|
62 |
| - TransformerMixin)): |
| 62 | + SelectorMixin)): |
63 | 63 | """Base class to implement randomized linear models for feature selection
|
64 | 64 |
|
65 | 65 | This implements the strategy by Meinshausen and Buhlman:
|
@@ -87,7 +87,7 @@ def fit(self, X, y):
|
87 | 87 | Returns
|
88 | 88 | -------
|
89 | 89 | self : object
|
90 |
| - Returns an instance of self. |
| 90 | + Returns an instance of self. |
91 | 91 | """
|
92 | 92 | X, y = check_X_y(X, y, ['csr', 'csc'], y_numeric=True,
|
93 | 93 | ensure_min_samples=2, estimator=self)
|
@@ -121,31 +121,17 @@ def _make_estimator_and_params(self, X, y):
|
121 | 121 | """Return the parameters passed to the estimator"""
|
122 | 122 | raise NotImplementedError
|
123 | 123 |
|
124 |
| - def get_support(self, indices=False): |
125 |
| - """Return a mask, or list, of the features/indices selected.""" |
126 |
| - check_is_fitted(self, 'scores_') |
| 124 | + def _get_support_mask(self): |
| 125 | + """Get the boolean mask indicating which features are selected. |
127 | 126 |
|
128 |
| - mask = self.scores_ > self.selection_threshold |
129 |
| - return mask if not indices else np.where(mask)[0] |
130 |
| - |
131 |
| - # XXX: the two function below are copy/pasted from feature_selection, |
132 |
| - # Should we add an intermediate base class? |
133 |
| - def transform(self, X): |
134 |
| - """Transform a new matrix using the selected features""" |
135 |
| - mask = self.get_support() |
136 |
| - X = check_array(X) |
137 |
| - if len(mask) != X.shape[1]: |
138 |
| - raise ValueError("X has a different shape than during fitting.") |
139 |
| - return check_array(X)[:, safe_mask(X, mask)] |
140 |
| - |
141 |
| - def inverse_transform(self, X): |
142 |
| - """Transform a new matrix using the selected features""" |
143 |
| - support = self.get_support() |
144 |
| - if X.ndim == 1: |
145 |
| - X = X[None, :] |
146 |
| - Xt = np.zeros((X.shape[0], support.size)) |
147 |
| - Xt[:, support] = X |
148 |
| - return Xt |
| 127 | + Returns |
| 128 | + ------- |
| 129 | + support : boolean array of shape [# input features] |
| 130 | + An element is True iff its corresponding feature is selected |
| 131 | + for retention. |
| 132 | + """ |
| 133 | + check_is_fitted(self, 'scores_') |
| 134 | + return self.scores_ > self.selection_threshold |
149 | 135 |
|
150 | 136 |
|
151 | 137 | ###############################################################################
|
|
0 commit comments