Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MRG] Raise exception on providing complex data to estimators #9551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _yield_non_meta_checks(name, estimator):
yield check_sample_weights_pandas_series
yield check_sample_weights_list
yield check_estimators_fit_returns_self
yield check_complex_data

# Check that all estimator yield informative messages when
# trained on empty datasets
Expand Down Expand Up @@ -458,6 +459,16 @@ def check_dtype_object(name, estimator_orig):
assert_raises_regex(TypeError, msg, estimator.fit, X, y)


def check_complex_data(name, estimator_orig):
# check that estimators raise an exception on providing complex data
X = np.random.sample(10) + 1j * np.random.sample(10)
X = X.reshape(-1, 1)
y = np.random.sample(10) + 1j * np.random.sample(10)
estimator = clone(estimator_orig)
assert_raises_regex(ValueError, "Complex data not supported",
estimator.fit, X, y)


@ignore_warnings
def check_dict_unchanged(name, estimator_orig):
# this estimator raises
Expand Down
40 changes: 40 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,46 @@ def test_check_array_min_samples_and_features_messages():
assert_array_equal(y, y_checked)


def test_check_array_complex_data_error():
# np array
X = np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]])
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# list of lists
X = [[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# tuple of tuples
X = ((1 + 2j, 3 + 4j, 5 + 7j), (2 + 3j, 4 + 5j, 6 + 7j))
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# list of np arrays
X = [np.array([1 + 2j, 3 + 4j, 5 + 7j]),
np.array([2 + 3j, 4 + 5j, 6 + 7j])]
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# tuple of np arrays
X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]),
np.array([2 + 3j, 4 + 5j, 6 + 7j]))
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# dataframe
X = MockDataFrame(
np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]))
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)

# sparse matrix
X = sp.coo_matrix([[0, 1 + 2j], [0, 0]])
assert_raises_regexp(
ValueError, "Complex data not supported", check_array, X)


def test_has_fit_parameter():
assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight"))
assert_true(has_fit_parameter(RandomForestRegressor, "sample_weight"))
Expand Down
28 changes: 27 additions & 1 deletion sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import scipy.sparse as sp
from numpy.core.numeric import ComplexWarning

from ..externals import six
from ..utils.fixes import signature
Expand Down Expand Up @@ -276,6 +277,13 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
return spmatrix


def _ensure_no_complex_data(array):
if hasattr(array, 'dtype') and array.dtype is not None \
and hasattr(array.dtype, 'kind') and array.dtype.kind == "c":
raise ValueError("Complex data not supported\n"
"{}\n".format(array))


def check_array(array, accept_sparse=False, dtype="numeric", order=None,
copy=False, force_all_finite=True, ensure_2d=True,
allow_nd=False, ensure_min_samples=1, ensure_min_features=1,
Expand Down Expand Up @@ -396,10 +404,28 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
context = " by %s" % estimator_name if estimator is not None else ""

if sp.issparse(array):
_ensure_no_complex_data(array)
array = _ensure_sparse_format(array, accept_sparse, dtype, copy,
force_all_finite)
else:
array = np.array(array, dtype=dtype, order=order, copy=copy)
# If np.array(..) gives ComplexWarning, then we convert the warning
# to an error. This is needed because specifying a non complex
# dtype to the function converts complex to real dtype,
# thereby passing the test made in the lines following the scope
# of warnings context manager.
with warnings.catch_warnings():
try:
warnings.simplefilter('error', ComplexWarning)
array = np.array(array, dtype=dtype, order=order, copy=copy)
except ComplexWarning:
raise ValueError("Complex data not supported\n"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this branch is covered, but I don't see where. We're never passing dtype right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never passing dtype to check_array? Even if we don't pass it to check_array, the value of dtype changes within this function. For example lines 381-386. The ComplexWarning comes only if we are setting dtype to some one of the real types. However, if dtype is "None", no conversion takes place and no warning is produced, so we need to check that case again in subsequentl lines

Copy link
Contributor Author

@pravarmahajan pravarmahajan Aug 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One example where dtype is explicitly passed to the function is that of svm decision function

"{}\n".format(array))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, at this point the array has been converted to float or int and the error message will be confusing. We need to keep a reference to the original input with complex values to report it in this error message (and del original_array otherwise to let the gabage collector free the memory as soon as possible otherwise).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe:

        with warnings.catch_warnings():
            try:
                warnings.simplefilter('error', ComplexWarning)
                new_array = np.array(array, dtype=dtype, order=order, copy=copy)
            except ComplexWarning:
                raise ValueError("Complex data not supported\n"
                                 "{}\n".format(array))
            array = new_array
            del new_array

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum sorry my previous comments are wrong. Your code is fine.


# It is possible that the np.array(..) gave no warning. This happens
# when no dtype conversion happend, for example dtype = None. The
# result is that np.array(..) produces an array of complex dtype
# and we need to catch and raise exception for such cases.
_ensure_no_complex_data(array)

if ensure_2d:
if array.ndim == 1:
Expand Down