Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 54d91d3

Browse files
authored
FIX Fixes check_array for pd.NA in a series (#25080)
1 parent c047241 commit 54d91d3

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,10 @@ Changelog
713713
- |Fix| :func:`utils.estimator_checks.check_estimator` now takes into account
714714
the `requires_positive_X` tag correctly. :pr:`24667` by `Thomas Fan`_.
715715

716+
- |Fix| :func:`utils.check_array` now supports Pandas Series with `pd.NA`
717+
by raising a better error message or returning a compatible `ndarray`.
718+
:pr:`25080` by `Thomas Fan`_.
719+
716720
- |API| The extra keyword parameters of :func:`utils.extmath.density` are deprecated
717721
and will be removed in 1.4.
718722
:pr:`24523` by :user:`Mia Bajic <clytaemnestra>`.

sklearn/utils/tests/test_validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,27 @@ def test_check_array_pandas_na_support(pd_dtype, dtype, expected_dtype):
447447
check_array(X, force_all_finite=True)
448448

449449

450+
def test_check_array_panadas_na_support_series():
451+
"""Check check_array is correct with pd.NA in a series."""
452+
pd = pytest.importorskip("pandas")
453+
454+
X_int64 = pd.Series([1, 2, pd.NA], dtype="Int64")
455+
456+
msg = "Input contains NaN"
457+
with pytest.raises(ValueError, match=msg):
458+
check_array(X_int64, force_all_finite=True, ensure_2d=False)
459+
460+
X_out = check_array(X_int64, force_all_finite=False, ensure_2d=False)
461+
assert_allclose(X_out, [1, 2, np.nan])
462+
assert X_out.dtype == np.float64
463+
464+
X_out = check_array(
465+
X_int64, force_all_finite=False, ensure_2d=False, dtype=np.float32
466+
)
467+
assert_allclose(X_out, [1, 2, np.nan])
468+
assert X_out.dtype == np.float32
469+
470+
450471
def test_check_array_pandas_dtype_casting():
451472
# test that data-frames with homogeneous dtype are not upcast
452473
pd = pytest.importorskip("pandas")

sklearn/utils/validation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,13 @@ def check_array(
777777
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
778778
dtype_orig = np.result_type(*dtypes_orig)
779779

780+
elif hasattr(array, "iloc") and hasattr(array, "dtype"):
781+
# array is a pandas series
782+
pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
783+
if pandas_requires_conversion:
784+
# Set to None, to convert to a np.dtype that works with array.dtype
785+
dtype_orig = None
786+
780787
if dtype_numeric:
781788
if dtype_orig is not None and dtype_orig.kind == "O":
782789
# if input is object, convert to float.

0 commit comments

Comments
 (0)