From 8bf05bdaf07cbf684db9b5b43f5a0512527d1b89 Mon Sep 17 00:00:00 2001 From: Patrick Date: Sun, 29 Oct 2017 23:47:37 +0000 Subject: [PATCH 1/6] Add check for n_components in pca --- sklearn/decomposition/pca.py | 7 +++++++ sklearn/decomposition/tests/test_pca.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index cbd688f3d748d..6475d63ef8f68 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -383,6 +383,13 @@ def _fit(self, X): else: n_components = self.n_components + if n_components != "mle" and \ + (n_components > 1 and + not (np.issubdtype(type(n_components), np.integer))): + raise ValueError("n_components=%r must be of type int " + "when bigger than 1, was of type=%r" + % (n_components, type(n_components))) + # Handle svd_solver svd_solver = self.svd_solver if svd_solver == 'auto': diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index aa67189407296..c2cbc1e74ec49 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -9,6 +9,7 @@ from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_raises_regex +from sklearn.utils.testing import assert_raise_message from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import ignore_warnings @@ -389,6 +390,12 @@ def test_pca_validation(): PCA(n_components, svd_solver=solver) .fit, data) + n_components = 1.2 + type_ncom = type(n_components) + assert_raise_message(ValueError, "n_components={} must be of type int " + "when bigger than 1, was of type={}" + .format(n_components, type_ncom), + PCA(n_components).fit, data) def test_n_components_none(): # Ensures that n_components == None is handled correctly From c8fab7e4dd4d84ef01b20af44aec012445d4c70b Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 30 Oct 2017 10:34:03 +0000 Subject: [PATCH 2/6] Add more consistency to checks and more tests --- sklearn/decomposition/pca.py | 16 +++++++++------- sklearn/decomposition/tests/test_pca.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index 6475d63ef8f68..c2b21f27101b2 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -383,13 +383,6 @@ def _fit(self, X): else: n_components = self.n_components - if n_components != "mle" and \ - (n_components > 1 and - not (np.issubdtype(type(n_components), np.integer))): - raise ValueError("n_components=%r must be of type int " - "when bigger than 1, was of type=%r" - % (n_components, type(n_components))) - # Handle svd_solver svd_solver = self.svd_solver if svd_solver == 'auto': @@ -424,6 +417,11 @@ def _fit_full(self, X, n_components): "min(n_samples, n_features)=%r with " "svd_solver='full'" % (n_components, min(n_samples, n_features))) + elif n_components >= 1: + if not np.issubdtype(type(n_components), np.integer): + raise ValueError("n_components=%r must be of type int " + "when greater or equal to 1, was of type=%r" + % (n_components, type(n_components))) # Center data self.mean_ = np.mean(X, axis=0) @@ -484,6 +482,10 @@ def _fit_truncated(self, X, n_components, svd_solver): "svd_solver='%s'" % (n_components, min(n_samples, n_features), svd_solver)) + elif not np.issubdtype(type(n_components), np.integer): + raise ValueError("n_components=%r must be of type int " + "when greater or equal to 1, was of type=%r" + % (n_components, type(n_components))) elif svd_solver == 'arpack' and n_components == min(n_samples, n_features): raise ValueError("n_components=%r must be strictly less than " diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index c2cbc1e74ec49..fc83c2ef94e6e 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -390,13 +390,17 @@ def test_pca_validation(): PCA(n_components, svd_solver=solver) .fit, data) - n_components = 1.2 - type_ncom = type(n_components) - assert_raise_message(ValueError, "n_components={} must be of type int " - "when bigger than 1, was of type={}" + n_components = 1.0 + type_ncom = type(n_components) + assert_raise_message(ValueError, + "n_components={} must be of type int " + "when greater or equal to 1, " + "was of type={}" .format(n_components, type_ncom), PCA(n_components).fit, data) + + def test_n_components_none(): # Ensures that n_components == None is handled correctly X = iris.data From 23df5f7682aaa238bc700b43c11a78bf51f139fe Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 30 Oct 2017 10:38:20 +0000 Subject: [PATCH 3/6] Fix mistake in identation --- sklearn/decomposition/tests/test_pca.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index fc83c2ef94e6e..a8583519b25a6 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -390,14 +390,13 @@ def test_pca_validation(): PCA(n_components, svd_solver=solver) .fit, data) - n_components = 1.0 - type_ncom = type(n_components) - assert_raise_message(ValueError, - "n_components={} must be of type int " - "when greater or equal to 1, " - "was of type={}" - .format(n_components, type_ncom), - PCA(n_components).fit, data) + n_components = 1.0 + type_ncom = type(n_components) + assert_raise_message(ValueError, + "n_components={} must be of type int " + "when greater or equal to 1, was of type={}" + .format(n_components, type_ncom), + PCA(n_components).fit, data) From 11824dc6b9d323898459d52753f5c7d8030c8c52 Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 30 Oct 2017 11:04:09 +0000 Subject: [PATCH 4/6] Fix typo and bug in solver selection in test --- sklearn/decomposition/pca.py | 5 +++-- sklearn/decomposition/tests/test_pca.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index c2b21f27101b2..a6afbd09da3a1 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -420,7 +420,8 @@ def _fit_full(self, X, n_components): elif n_components >= 1: if not np.issubdtype(type(n_components), np.integer): raise ValueError("n_components=%r must be of type int " - "when greater or equal to 1, was of type=%r" + "when greater than or equal to 1, " + "was of type=%r" % (n_components, type(n_components))) # Center data @@ -484,7 +485,7 @@ def _fit_truncated(self, X, n_components, svd_solver): svd_solver)) elif not np.issubdtype(type(n_components), np.integer): raise ValueError("n_components=%r must be of type int " - "when greater or equal to 1, was of type=%r" + "when greater than or equal to 1, was of type=%r" % (n_components, type(n_components))) elif svd_solver == 'arpack' and n_components == min(n_samples, n_features): diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index a8583519b25a6..9285c9e46991f 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -394,9 +394,9 @@ def test_pca_validation(): type_ncom = type(n_components) assert_raise_message(ValueError, "n_components={} must be of type int " - "when greater or equal to 1, was of type={}" + "when greater than or equal to 1, was of type={}" .format(n_components, type_ncom), - PCA(n_components).fit, data) + PCA(n_components, svd_solver=solver).fit, data) From 3b269741b31568a193cd1aca0c5212b35e367c5c Mon Sep 17 00:00:00 2001 From: CoderPat Date: Mon, 30 Oct 2017 16:25:03 +0000 Subject: [PATCH 5/6] Add more consistency to type checking --- sklearn/decomposition/pca.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index a6afbd09da3a1..75d532bbbb6e1 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -418,7 +418,7 @@ def _fit_full(self, X, n_components): "svd_solver='full'" % (n_components, min(n_samples, n_features))) elif n_components >= 1: - if not np.issubdtype(type(n_components), np.integer): + if not isinstance(n_components, (int, np.integer)): raise ValueError("n_components=%r must be of type int " "when greater than or equal to 1, " "was of type=%r" @@ -483,7 +483,7 @@ def _fit_truncated(self, X, n_components, svd_solver): "svd_solver='%s'" % (n_components, min(n_samples, n_features), svd_solver)) - elif not np.issubdtype(type(n_components), np.integer): + elif not isinstance(n_components, (int, np.integer)): raise ValueError("n_components=%r must be of type int " "when greater than or equal to 1, was of type=%r" % (n_components, type(n_components))) From d1be5304924218f101c655d5869d8c2cf0248ce2 Mon Sep 17 00:00:00 2001 From: CoderPat Date: Mon, 30 Oct 2017 19:41:42 +0000 Subject: [PATCH 6/6] fix type checking for compatibilty with python2 long --- sklearn/decomposition/pca.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py index 75d532bbbb6e1..fccc702d2c2ce 100644 --- a/sklearn/decomposition/pca.py +++ b/sklearn/decomposition/pca.py @@ -11,6 +11,7 @@ # License: BSD 3 clause from math import log, sqrt +import numbers import numpy as np from scipy import linalg @@ -418,7 +419,7 @@ def _fit_full(self, X, n_components): "svd_solver='full'" % (n_components, min(n_samples, n_features))) elif n_components >= 1: - if not isinstance(n_components, (int, np.integer)): + if not isinstance(n_components, (numbers.Integral, np.integer)): raise ValueError("n_components=%r must be of type int " "when greater than or equal to 1, " "was of type=%r" @@ -483,7 +484,7 @@ def _fit_truncated(self, X, n_components, svd_solver): "svd_solver='%s'" % (n_components, min(n_samples, n_features), svd_solver)) - elif not isinstance(n_components, (int, np.integer)): + elif not isinstance(n_components, (numbers.Integral, np.integer)): raise ValueError("n_components=%r must be of type int " "when greater than or equal to 1, was of type=%r" % (n_components, type(n_components)))