diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 11a22fa425c6d..752c865519e2e 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -551,6 +551,10 @@ Changelog and sparse matrix. :pr:`14538` by :user:`Jérémie du Boisberranger `. +- |Fix| :func:`utils.check_array` is now raising an error instead of casting + NaN to integer. + :pr:`14872` by `Roman Yurchak`_. + :mod:`sklearn.metrics` .................................. diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 0f7ffe9a3e4f0..d5c0aa444a8e2 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -202,6 +202,26 @@ def test_check_array_force_all_finite_object(): check_array(X, dtype=None, force_all_finite=True) +@pytest.mark.parametrize( + "X, err_msg", + [(np.array([[1, np.nan]]), + "Input contains NaN, infinity or a value too large for.*int"), + (np.array([[1, np.nan]]), + "Input contains NaN, infinity or a value too large for.*int"), + (np.array([[1, np.inf]]), + "Input contains NaN, infinity or a value too large for.*int"), + (np.array([[1, np.nan]], dtype=np.object), + "cannot convert float NaN to integer")] +) +@pytest.mark.parametrize("force_all_finite", [True, False]) +def test_check_array_force_all_finite_object_unsafe_casting( + X, err_msg, force_all_finite): + # casting a float array containing NaN or inf to int dtype should + # raise an error irrespective of the force_all_finite parameter. + with pytest.raises(ValueError, match=err_msg): + check_array(X, dtype=np.int, force_all_finite=force_all_finite) + + @ignore_warnings def test_check_array(): # accept_sparse == False diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 465acf48e8293..5da8b6f2bed64 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -32,7 +32,7 @@ warnings.simplefilter('ignore', NonBLASDotWarning) -def _assert_all_finite(X, allow_nan=False): +def _assert_all_finite(X, allow_nan=False, msg_dtype=None): """Like assert_all_finite, but only for ndarray.""" # validation is also imported in extmath from .extmath import _safe_accumulator_op @@ -52,7 +52,11 @@ def _assert_all_finite(X, allow_nan=False): if (allow_nan and np.isinf(X).any() or not allow_nan and not np.isfinite(X).all()): type_err = 'infinity' if allow_nan else 'NaN, infinity' - raise ValueError(msg_err.format(type_err, X.dtype)) + raise ValueError( + msg_err.format + (type_err, + msg_dtype if msg_dtype is not None else X.dtype) + ) # for object dtype data, we only check for NaNs (GH-13254) elif X.dtype == np.dtype('object') and not allow_nan: if _object_dtype_isnan(X).any(): @@ -494,7 +498,17 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True, with warnings.catch_warnings(): try: warnings.simplefilter('error', ComplexWarning) - array = np.asarray(array, dtype=dtype, order=order) + if dtype is not None and np.dtype(dtype).kind in 'iu': + # Conversion float -> int should not contain NaN or + # inf (numpy#14412). We cannot use casting='safe' because + # then conversion float -> int would be disallowed. + array = np.asarray(array, order=order) + if array.dtype.kind == 'f': + _assert_all_finite(array, allow_nan=False, + msg_dtype=dtype) + array = array.astype(dtype, casting="unsafe", copy=False) + else: + array = np.asarray(array, order=order, dtype=dtype) except ComplexWarning: raise ValueError("Complex data not supported\n" "{}\n".format(array))