def check_complex_data(name, estimator_orig):
    """Check that an estimator refuses complex-valued data in ``fit``.

    Parameters
    ----------
    name : str
        Name of the estimator. Not used in the body; present so the check
        has the same ``(name, estimator_orig)`` signature as its siblings.
    estimator_orig : estimator instance
        Estimator under test. It is passed through ``clone`` and ``fit`` is
        called on the result rather than on this object.

    Notes
    -----
    The check passes when ``estimator.fit(X, y)`` raises a ``ValueError``
    whose message matches "Complex data not supported"; the assertion is
    delegated to ``assert_raises_regex``.
    """
    # Draw from a private, seeded generator instead of the global
    # ``np.random`` state: the check then sees the same data on every run
    # and neither depends on nor advances process-wide RNG state.
    rng = np.random.RandomState(42)

    # A single complex feature column with 10 samples.
    X = rng.random_sample(10) + 1j * rng.random_sample(10)
    X = X.reshape(-1, 1)
    # The target is complex as well.
    y = rng.random_sample(10) + 1j * rng.random_sample(10)

    estimator = clone(estimator_orig)
    assert_raises_regex(ValueError, "Complex data not supported",
                        estimator.fit, X, y)
[np.array([1 + 2j, 3 + 4j, 5 + 7j]), + np.array([2 + 3j, 4 + 5j, 6 + 7j])] + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + # tuple of np arrays + X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), + np.array([2 + 3j, 4 + 5j, 6 + 7j])) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + # dataframe + X = MockDataFrame( + np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]])) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + # sparse matrix + X = sp.coo_matrix([[0, 1 + 2j], [0, 0]]) + assert_raises_regexp( + ValueError, "Complex data not supported", check_array, X) + + def test_has_fit_parameter(): assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight")) assert_true(has_fit_parameter(RandomForestRegressor, "sample_weight")) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 460f20673feaf..4af88a7ed9bb6 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -13,6 +13,7 @@ import numpy as np import scipy.sparse as sp +from numpy.core.numeric import ComplexWarning from ..externals import six from ..utils.fixes import signature @@ -276,6 +277,13 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy, return spmatrix +def _ensure_no_complex_data(array): + if hasattr(array, 'dtype') and array.dtype is not None \ + and hasattr(array.dtype, 'kind') and array.dtype.kind == "c": + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) + + def check_array(array, accept_sparse=False, dtype="numeric", order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, ensure_min_samples=1, ensure_min_features=1, @@ -396,10 +404,28 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None, context = " by %s" % estimator_name if estimator is not None else "" if sp.issparse(array): + _ensure_no_complex_data(array) array = _ensure_sparse_format(array, 
accept_sparse, dtype, copy, force_all_finite) else: - array = np.array(array, dtype=dtype, order=order, copy=copy) + # If np.array(..) gives ComplexWarning, then we convert the warning + # to an error. This is needed because specifying a non complex + # dtype to the function converts complex to real dtype, + # thereby passing the test made in the lines following the scope + # of warnings context manager. + with warnings.catch_warnings(): + try: + warnings.simplefilter('error', ComplexWarning) + array = np.array(array, dtype=dtype, order=order, copy=copy) + except ComplexWarning: + raise ValueError("Complex data not supported\n" + "{}\n".format(array)) + + # It is possible that the np.array(..) gave no warning. This happens + # when no dtype conversion happened, for example dtype = None. The + # result is that np.array(..) produces an array of complex dtype + # and we need to catch and raise exception for such cases. + _ensure_no_complex_data(array) if ensure_2d: if array.ndim == 1: