From 9a9cbfe4f46eb2f8b31ebf9e11213fd58f07cd1f Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Sat, 12 Aug 2017 07:02:42 -0500 Subject: [PATCH 01/10] Modifies model_selection.cross_validate docstring (#9534) - Fixes rendering of docstring examples - Instead of importing cross_val_score in example, cross_validate is imported --- sklearn/model_selection/_validation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 147d741b500b9..f8c62982aafec 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -144,7 +144,7 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, Examples -------- >>> from sklearn import datasets, linear_model - >>> from sklearn.model_selection import cross_val_score + >>> from sklearn.model_selection import cross_validate >>> from sklearn.metrics.scorer import make_scorer >>> from sklearn.metrics import confusion_matrix >>> from sklearn.svm import LinearSVC @@ -153,15 +153,17 @@ def cross_validate(estimator, X, y=None, groups=None, scoring=None, cv=None, >>> y = diabetes.target[:150] >>> lasso = linear_model.Lasso() - # single metric evaluation using cross_validate + Single metric evaluation using ``cross_validate`` + >>> cv_results = cross_validate(lasso, X, y, return_train_score=False) >>> sorted(cv_results.keys()) # doctest: +ELLIPSIS ['fit_time', 'score_time', 'test_score'] >>> cv_results['test_score'] # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE array([ 0.33..., 0.08..., 0.03...]) - # Multiple metric evaluation using cross_validate - # (Please refer the ``scoring`` parameter doc for more information) + Multiple metric evaluation using ``cross_validate`` + (please refer the ``scoring`` parameter doc for more information) + >>> scores = cross_validate(lasso, X, y, ... 
scoring=('r2', 'neg_mean_squared_error')) >>> print(scores['test_neg_mean_squared_error']) # doctest: +ELLIPSIS From 8d55a60606262ab26cbcd9481de17d490247d7f6 Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Mon, 14 Aug 2017 14:11:56 -0400 Subject: [PATCH 02/10] raise error on complex data input to estimators --- sklearn/utils/tests/test_validation.py | 12 ++++++++++++ sklearn/utils/validation.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 1fe27f199ac63..7b6b77d19aab5 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -432,6 +432,18 @@ def test_check_array_min_samples_and_features_messages(): assert_array_equal(X, X_checked) assert_array_equal(y, y_checked) +def test_check_array_complex_data_error(): + X = np.array([[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]]) + assert_raises(ValueError, check_array, X) + X = [[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]] + assert_raises(ValueError, check_array, X) + X = ((1+2j, 3+4j, 5+7j), (2+3j, 4+5j, 6+7j)) + assert_raises(ValueError, check_array, X) + X = [np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])] + assert_raises(ValueError, check_array, X) + X = (np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])) + assert_raises(ValueError, check_array, X) + def test_has_fit_parameter(): assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight")) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 460f20673feaf..f963e320ea98e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -366,10 +366,19 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, dtype_numeric = isinstance(dtype, six.string_types) and dtype == "numeric" dtype_orig = getattr(array, "dtype", None) + if not hasattr(dtype_orig, 'kind'): # not a data type (e.g. 
a column named dtype in a pandas DataFrame) dtype_orig = None + if dtype_orig is not None and dtype_orig.kind == "c": + raise ValueError("Complex data is not supported\n{}\n".format(array)) + elif isinstance(array, list) or isinstance(array, tuple): + np_array = np.array(array) + if np_array.dtype.kind == "c": + raise ValueError("Complex data is not supported" + "\n{}\n".format(array)) + if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. From cd863c841d058334aa8fa08845a5907a676f4d9d Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Tue, 15 Aug 2017 19:02:05 -0400 Subject: [PATCH 03/10] Raise exception on providing complex data to estimators --- sklearn/utils/tests/test_validation.py | 24 ++++++++++++++++++------ sklearn/utils/validation.py | 16 ++++++++-------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 7b6b77d19aab5..34af2fc06c040 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -434,16 +434,28 @@ def test_check_array_min_samples_and_features_messages(): def test_check_array_complex_data_error(): X = np.array([[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]]) - assert_raises(ValueError, check_array, X) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + X = [[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]] - assert_raises(ValueError, check_array, X) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + X = ((1+2j, 3+4j, 5+7j), (2+3j, 4+5j, 6+7j)) - assert_raises(ValueError, check_array, X) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + X = [np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])] - assert_raises(ValueError, check_array, X) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + X = (np.array([1+2j, 3+4j, 5+7j]), 
np.array([2+3j, 4+5j, 6+7j])) - assert_raises(ValueError, check_array, X) - + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + + X = (np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + + X = MockDataFrame(np.array([[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]])) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) + + X = sp.coo_matrix([[0, 1+2j], [0, 0]]) + assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) def test_has_fit_parameter(): assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight")) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index f963e320ea98e..a079bed813236 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -276,6 +276,12 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, return spmatrix +def _ensure_non_complex_data(array): + if array.dtype.kind == "c": + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) + + def check_array(array, accept_sparse=False, dtype="numeric", order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, ensure_min_samples=1, ensure_min_features=1, @@ -371,14 +377,6 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # not a data type (e.g. a column named dtype in a pandas DataFrame) dtype_orig = None - if dtype_orig is not None and dtype_orig.kind == "c": - raise ValueError("Complex data is not supported\n{}\n".format(array)) - elif isinstance(array, list) or isinstance(array, tuple): - np_array = np.array(array) - if np_array.dtype.kind == "c": - raise ValueError("Complex data is not supported" - "\n{}\n".format(array)) - if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. 
@@ -405,10 +403,12 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, context = " by %s" % estimator_name if estimator is not None else "" if sp.issparse(array): + _ensure_non_complex_data(array) array = _ensure_sparse_format(array, accept_sparse, dtype, copy, force_all_finite) else: array = np.array(array, dtype=dtype, order=order, copy=copy) + _ensure_non_complex_data(array) if ensure_2d: if array.ndim == 1: From b84e1eeb0f7c0febacc985af2c91f45ef5a30a38 Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 18:44:51 -0400 Subject: [PATCH 04/10] adding checks to check_estimator for complex data --- sklearn/utils/estimator_checks.py | 10 +++++++++ sklearn/utils/tests/test_validation.py | 1 + sklearn/utils/validation.py | 30 ++++++++++++++++++-------- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index c3b066e5e31be..320b8ee873087 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -76,6 +76,7 @@ def _yield_non_meta_checks(name, estimator): yield check_sample_weights_pandas_series yield check_sample_weights_list yield check_estimators_fit_returns_self + yield check_complex_data # Check that all estimator yield informative messages when # trained on empty datasets @@ -457,6 +458,15 @@ def check_dtype_object(name, estimator_orig): msg = "argument must be a string or a number" assert_raises_regex(TypeError, msg, estimator.fit, X, y) +def check_complex_data(name, estimator_orig): + # check that estimators raise an exception on providing complex data + X = np.random.sample(10) + 1j*np.random.sample(10) + X = X.reshape(-1, 1) + y = np.random.sample(10) + 1j*np.random.sample(10) + estimator = clone(estimator_orig) + assert_raises_regex(ValueError, "Complex data not supported", + estimator.fit, X, y) + @ignore_warnings def check_dict_unchanged(name, estimator_orig): diff --git 
a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 34af2fc06c040..c2ab0456332e0 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1,6 +1,7 @@ """Tests for input validation functions""" import warnings +import ipdb from tempfile import NamedTemporaryFile from itertools import product diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index a079bed813236..04d5e648f2ea4 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -13,6 +13,8 @@ import numpy as np import scipy.sparse as sp +from numpy.core.numeric import ComplexWarning +import ipdb from ..externals import six from ..utils.fixes import signature @@ -21,13 +23,12 @@ from ..exceptions import NotFittedError from ..exceptions import DataConversionWarning - FLOAT_DTYPES = (np.float64, np.float32, np.float16) # Silenced by default to reduce verbosity. Turn on at runtime for # performance profiling. warnings.simplefilter('ignore', NonBLASDotWarning) - +warnings.simplefilter('error', ComplexWarning) def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" @@ -276,10 +277,11 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, return spmatrix -def _ensure_non_complex_data(array): - if array.dtype.kind == "c": - raise ValueError("Complex data not supported\n" - "{}\n".format(array)) +def _ensure_no_complex_data(array): + if hasattr(array, 'dtype') and array.dtype is not None \ + and hasattr(array.dtype, 'kind') and array.dtype.kind == "c": + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) def check_array(array, accept_sparse=False, dtype="numeric", order=None, @@ -377,6 +379,8 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # not a data type (e.g. 
a column named dtype in a pandas DataFrame) dtype_orig = None + _ensure_no_complex_data(array) + if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. @@ -403,12 +407,20 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, context = " by %s" % estimator_name if estimator is not None else "" if sp.issparse(array): - _ensure_non_complex_data(array) array = _ensure_sparse_format(array, accept_sparse, dtype, copy, force_all_finite) else: - array = np.array(array, dtype=dtype, order=order, copy=copy) - _ensure_non_complex_data(array) + with warnings.catch_warnings(): + try: + warnings.simplefilter('error', ComplexWarning) + array = np.array(array, dtype=dtype, order=order, copy=copy) + except ComplexWarning: + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) + + if array.dtype.kind == "c": + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) if ensure_2d: if array.ndim == 1: From 7d977f2001018fd254ec2bb463b9151e50e021f3 Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 18:47:21 -0400 Subject: [PATCH 05/10] removing some unnecessary parts --- sklearn/utils/validation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 04d5e648f2ea4..11f89d57dd0d0 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -14,7 +14,6 @@ import numpy as np import scipy.sparse as sp from numpy.core.numeric import ComplexWarning -import ipdb from ..externals import six from ..utils.fixes import signature @@ -28,7 +27,6 @@ # Silenced by default to reduce verbosity. Turn on at runtime for # performance profiling. 
warnings.simplefilter('ignore', NonBLASDotWarning) -warnings.simplefilter('error', ComplexWarning) def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" From e4b437f9a5783ca7b8d8983a05a6a336648a037b Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 18:53:35 -0400 Subject: [PATCH 06/10] autopep8 changes --- sklearn/utils/estimator_checks.py | 7 +-- sklearn/utils/tests/test_validation.py | 61 ++++++++++++++++---------- sklearn/utils/validation.py | 7 +-- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 320b8ee873087..646514205c4fc 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -119,7 +119,7 @@ def _yield_classifier_checks(name, classifier): if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and # TODO some complication with -1 label - name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): + name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): # We don't raise a warning in these classifiers, as # the column y interface is used by the forests. 
@@ -458,11 +458,12 @@ def check_dtype_object(name, estimator_orig): msg = "argument must be a string or a number" assert_raises_regex(TypeError, msg, estimator.fit, X, y) + def check_complex_data(name, estimator_orig): # check that estimators raise an exception on providing complex data - X = np.random.sample(10) + 1j*np.random.sample(10) + X = np.random.sample(10) + 1j * np.random.sample(10) X = X.reshape(-1, 1) - y = np.random.sample(10) + 1j*np.random.sample(10) + y = np.random.sample(10) + 1j * np.random.sample(10) estimator = clone(estimator_orig) assert_raises_regex(ValueError, "Complex data not supported", estimator.fit, X, y) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index c2ab0456332e0..88a7e90c2f39f 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -40,6 +40,7 @@ from sklearn.utils.testing import assert_raise_message + def test_as_float_array(): # Test function for as_float_array X = np.ones((3, 10), dtype=np.int32) @@ -97,7 +98,7 @@ def test_np_matrix(): def test_memmap(): # Confirm that input validation code doesn't copy memory mapped arrays - asflt = lambda x: as_float_array(x, copy=False) + def asflt(x): return as_float_array(x, copy=False) with NamedTemporaryFile(prefix='sklearn-test') as tmp: M = np.memmap(tmp, shape=(10, 10), dtype=np.float32) @@ -433,30 +434,44 @@ def test_check_array_min_samples_and_features_messages(): assert_array_equal(X, X_checked) assert_array_equal(y, y_checked) -def test_check_array_complex_data_error(): - X = np.array([[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]]) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - - X = [[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]] - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - - X = ((1+2j, 3+4j, 5+7j), (2+3j, 4+5j, 6+7j)) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - X = [np.array([1+2j, 3+4j, 
5+7j]), np.array([2+3j, 4+5j, 6+7j])] - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - - X = (np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - - X = (np.array([1+2j, 3+4j, 5+7j]), np.array([2+3j, 4+5j, 6+7j])) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) - - X = MockDataFrame(np.array([[1+2j, 3+4j, 5+7j], [2+3j, 4+5j, 6+7j]])) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) +def test_check_array_complex_data_error(): + X = np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = [[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]] + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = ((1 + 2j, 3 + 4j, 5 + 7j), (2 + 3j, 4 + 5j, 6 + 7j)) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = [np.array([1 + 2j, 3 + 4j, 5 + 7j]), + np.array([2 + 3j, 4 + 5j, 6 + 7j])] + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), + np.array([2 + 3j, 4 + 5j, 6 + 7j])) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), + np.array([2 + 3j, 4 + 5j, 6 + 7j])) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = MockDataFrame( + np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]])) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + X = sp.coo_matrix([[0, 1 + 2j], [0, 0]]) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) - X = sp.coo_matrix([[0, 1+2j], [0, 0]]) - assert_raises_regexp(ValueError, "Complex data not supported", check_array, X) def 
test_has_fit_parameter(): assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight")) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 11f89d57dd0d0..6ace077c40712 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -28,6 +28,7 @@ # performance profiling. warnings.simplefilter('ignore', NonBLASDotWarning) + def _assert_all_finite(X): """Like assert_all_finite, but only for ndarray.""" if _get_config()['assume_finite']: @@ -277,9 +278,9 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, def _ensure_no_complex_data(array): if hasattr(array, 'dtype') and array.dtype is not None \ - and hasattr(array.dtype, 'kind') and array.dtype.kind == "c": - raise ValueError("Complex data not supported\n" - "{}\n".format(array)) + and hasattr(array.dtype, 'kind') and array.dtype.kind == "c": + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) def check_array(array, accept_sparse=False, dtype="numeric", order=None, From b26e99d4c39c9f800d108cb979849bb040ac56ea Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 19:48:38 -0400 Subject: [PATCH 07/10] removing ipdb, restoring some autopep8 fixes --- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_validation.py | 3 +-- sklearn/utils/validation.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 646514205c4fc..068902580719f 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -119,7 +119,7 @@ def _yield_classifier_checks(name, classifier): if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and # TODO some complication with -1 label - name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): + name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): # We don't raise a warning in these classifiers, as # the column y interface is used by the 
forests. diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 88a7e90c2f39f..8e75244ebd203 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1,7 +1,6 @@ """Tests for input validation functions""" import warnings -import ipdb from tempfile import NamedTemporaryFile from itertools import product @@ -98,7 +97,7 @@ def test_np_matrix(): def test_memmap(): # Confirm that input validation code doesn't copy memory mapped arrays - def asflt(x): return as_float_array(x, copy=False) + asflt = lambda x: as_float_array(x, copy=False) with NamedTemporaryFile(prefix='sklearn-test') as tmp: M = np.memmap(tmp, shape=(10, 10), dtype=np.float32) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 6ace077c40712..85f9fb5def3c5 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -373,7 +373,6 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, dtype_numeric = isinstance(dtype, six.string_types) and dtype == "numeric" dtype_orig = getattr(array, "dtype", None) - if not hasattr(dtype_orig, 'kind'): # not a data type (e.g. 
a column named dtype in a pandas DataFrame) dtype_orig = None From 694b3b3a34b7bf3677871bf64ece0ad9ddde9c0b Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 19:51:39 -0400 Subject: [PATCH 08/10] removing ipdb, restoring some autopep8 fixes --- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_validation.py | 1 - sklearn/utils/validation.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 068902580719f..fa8e24c588555 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -119,7 +119,7 @@ def _yield_classifier_checks(name, classifier): if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and # TODO some complication with -1 label - name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): + name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]): # We don't raise a warning in these classifiers, as # the column y interface is used by the forests. diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 8e75244ebd203..59ed3a76403a9 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -39,7 +39,6 @@ from sklearn.utils.testing import assert_raise_message - def test_as_float_array(): # Test function for as_float_array X = np.ones((3, 10), dtype=np.int32) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 85f9fb5def3c5..cffa5c7ef6076 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -22,6 +22,7 @@ from ..exceptions import NotFittedError from ..exceptions import DataConversionWarning + FLOAT_DTYPES = (np.float64, np.float32, np.float16) # Silenced by default to reduce verbosity. 
Turn on at runtime for From 843230f5dbdac2685ca40a8b73c894627aa592df Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Wed, 16 Aug 2017 20:56:14 -0400 Subject: [PATCH 09/10] adding documentation for complex data handling --- sklearn/utils/validation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index cffa5c7ef6076..4af88a7ed9bb6 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -378,8 +378,6 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, # not a data type (e.g. a column named dtype in a pandas DataFrame) dtype_orig = None - _ensure_no_complex_data(array) - if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float. @@ -406,9 +404,15 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, context = " by %s" % estimator_name if estimator is not None else "" if sp.issparse(array): + _ensure_no_complex_data(array) array = _ensure_sparse_format(array, accept_sparse, dtype, copy, force_all_finite) else: + # If np.array(..) gives ComplexWarning, then we convert the warning + # to an error. This is needed because specifying a non complex + # dtype to the function converts complex to real dtype, + # thereby passing the test made in the lines following the scope + # of warnings context manager. with warnings.catch_warnings(): try: warnings.simplefilter('error', ComplexWarning) @@ -417,9 +421,11 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, raise ValueError("Complex data not supported\n" "{}\n".format(array)) - if array.dtype.kind == "c": - raise ValueError("Complex data not supported\n" - "{}\n".format(array)) + # It is possible that the np.array(..) gave no warning. This happens + # when no dtype conversion happened, for example dtype = None. The + # result is that np.array(..) 
produces an array of complex dtype + # and we need to catch and raise exception for such cases. + _ensure_no_complex_data(array) if ensure_2d: if array.ndim == 1: From 51d7f47f286d270d4b28aca9d7d2e58d87badf4e Mon Sep 17 00:00:00 2001 From: Pravar Mahajan Date: Thu, 17 Aug 2017 18:48:45 -0400 Subject: [PATCH 10/10] adding one line explanation for each test case --- sklearn/utils/tests/test_validation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 59ed3a76403a9..cfdc03341e8bc 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -434,38 +434,40 @@ def test_check_array_min_samples_and_features_messages(): def test_check_array_complex_data_error(): + # np array X = np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]) assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) + # list of lists X = [[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]] assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) + # tuple of tuples X = ((1 + 2j, 3 + 4j, 5 + 7j), (2 + 3j, 4 + 5j, 6 + 7j)) assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) + # list of np arrays X = [np.array([1 + 2j, 3 + 4j, 5 + 7j]), np.array([2 + 3j, 4 + 5j, 6 + 7j])] assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) + # tuple of np arrays X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), np.array([2 + 3j, 4 + 5j, 6 + 7j])) assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) - X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), - np.array([2 + 3j, 4 + 5j, 6 + 7j])) - assert_raises_regexp( - ValueError, "Complex data not supported", check_array, X) - + # dataframe X = MockDataFrame( np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]])) assert_raises_regexp( ValueError, "Complex data not supported", check_array, X) + # 
sparse matrix X = sp.coo_matrix([[0, 1 + 2j], [0, 0]]) assert_raises_regexp( ValueError, "Complex data not supported", check_array, X)