diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index cbd688f3d748d..fccc702d2c2ce 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -11,6 +11,7 @@ # License: BSD 3 clause from math import log, sqrt +import numbers import numpy as np from scipy import linalg @@ -417,6 +418,12 @@ def _fit_full(self, X, n_components): "min(n_samples, n_features)=%r with " "svd_solver='full'" % (n_components, min(n_samples, n_features))) + elif n_components >= 1: + if not isinstance(n_components, (numbers.Integral, np.integer)): + raise ValueError("n_components=%r must be of type int " + "when greater than or equal to 1, " + "was of type=%r" + % (n_components, type(n_components))) # Center data self.mean_ = np.mean(X, axis=0) @@ -477,6 +484,10 @@ def _fit_truncated(self, X, n_components, svd_solver): "svd_solver='%s'" % (n_components, min(n_samples, n_features), svd_solver)) + elif not isinstance(n_components, (numbers.Integral, np.integer)): + raise ValueError("n_components=%r must be of type int " + "when greater than or equal to 1, was of type=%r" + % (n_components, type(n_components))) elif svd_solver == 'arpack' and n_components == min(n_samples, n_features): raise ValueError("n_components=%r must be strictly less than " diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index aa67189407296..9285c9e46991f 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -9,6 +9,7 @@ from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_raises_regex +from sklearn.utils.testing import assert_raise_message from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import ignore_warnings @@ -389,6 +390,15 @@ def test_pca_validation(): PCA(n_components, svd_solver=solver) .fit, data) + n_components = 1.0 + type_ncom = type(n_components) + assert_raise_message(ValueError, + "n_components={} must be of type int " + "when greater than or equal to 1, was of type={}" + .format(n_components, type_ncom), + PCA(n_components, svd_solver=solver).fit, data) + + def test_n_components_none(): # Ensures that n_components == None is handled correctly