From e01314b4aa6dfd07c7fd30d12e3dd0182bc6a488 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 3 Feb 2021 17:32:20 -0500 Subject: [PATCH 1/2] ENH Checks n_features_in_ in discriminant_analysis --- sklearn/discriminant_analysis.py | 5 ++--- sklearn/tests/test_common.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 1e82578e2693b..fb7e7910f3f33 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -18,7 +18,6 @@ from .linear_model._base import LinearClassifierMixin from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance from .utils.multiclass import unique_labels -from .utils import check_array from .utils.validation import check_is_fitted from .utils.multiclass import check_classification_targets from .utils.extmath import softmax @@ -586,7 +585,7 @@ def transform(self, X): "solver (use 'svd' or 'eigen').") check_is_fitted(self) - X = check_array(X) + X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False) if self.solver == 'svd': X_new = np.dot(X - self.xbar_, self.scalings_) elif self.solver == 'eigen': @@ -824,7 +823,7 @@ def _decision_function(self, X): # return log posterior, see eq (4.12) p. 110 of the ESL. check_is_fitted(self) - X = check_array(X) + X = self._validate_data(X, reset=False) norm2 = [] for i in range(len(self.classes_)): R = self.rotations_[i] diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 37b6e666238b8..3c8743518d57f 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -267,7 +267,6 @@ def test_search_cv(estimator, check, request): 'calibration', 'compose', 'covariance', - 'discriminant_analysis', 'ensemble', 'feature_extraction', 'feature_selection', From f8b987fd6a5a636933e12e5acbd71ac6615573b4 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 4 Feb 2021 10:02:14 -0500 Subject: [PATCH 2/2] ENH Remove dtype casting --- sklearn/discriminant_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index fb7e7910f3f33..c5c18ac9136d2 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -585,7 +585,7 @@ def transform(self, X): "solver (use 'svd' or 'eigen').") check_is_fitted(self) - X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False) + X = self._validate_data(X, reset=False) if self.solver == 'svd': X_new = np.dot(X - self.xbar_, self.scalings_) elif self.solver == 'eigen':