diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 5d49a6fd281c7..daec37e96ac16 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -110,6 +110,13 @@ Changelog reconstruction of a `X` target when a `Y` parameter is given. :pr:`19680` by :user:`Robin Thibaut `. +:mod:`sklearn.discriminant_analysis` +.................................... + +- |API| Adds :term:`get_feature_names_out` to + :class:`discriminant_analysis.LinearDiscriminantAnalysis`. :pr:`22120` by + `Thomas Fan`_. + :mod:`sklearn.feature_extraction` ................................. diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 79faa8694a535..65960bd044a30 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -15,6 +15,7 @@ from scipy.special import expit from .base import BaseEstimator, TransformerMixin, ClassifierMixin +from .base import _ClassNamePrefixFeaturesOutMixin from .linear_model._base import LinearClassifierMixin from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance from .utils.multiclass import unique_labels @@ -165,7 +166,10 @@ def _class_cov(X, y, priors, shrinkage=None, covariance_estimator=None): class LinearDiscriminantAnalysis( - LinearClassifierMixin, TransformerMixin, BaseEstimator + _ClassNamePrefixFeaturesOutMixin, + LinearClassifierMixin, + TransformerMixin, + BaseEstimator, ): """Linear Discriminant Analysis. @@ -614,6 +618,7 @@ def fit(self, X, y): self.intercept_ = np.array( self.intercept_[1] - self.intercept_[0], ndmin=1, dtype=X.dtype ) + self._n_features_out = self._max_components return self def transform(self, X): diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 3611d52085b9e..2a0bb8d085b77 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -382,7 +382,6 @@ def test_pandas_column_name_consistency(estimator): GET_FEATURES_OUT_MODULES_TO_IGNORE = [ "cluster", "cross_decomposition", - "discriminant_analysis", "ensemble", "isotonic", "kernel_approximation", diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index 90e5b57a63779..40ede1feba547 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -650,3 +650,20 @@ def test_raises_value_error_on_same_number_of_classes_and_samples(solver): clf = LinearDiscriminantAnalysis(solver=solver) with pytest.raises(ValueError, match="The number of samples must be more"): clf.fit(X, y) + + +def test_get_feature_names_out(): + """Check get_feature_names_out uses class name as prefix.""" + + est = LinearDiscriminantAnalysis().fit(X, y) + names_out = est.get_feature_names_out() + + class_name_lower = "LinearDiscriminantAnalysis".lower() + expected_names_out = np.array( + [ + f"{class_name_lower}{i}" + for i in range(est.explained_variance_ratio_.shape[0]) + ], + dtype=object, + ) + assert_array_equal(names_out, expected_names_out)