Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 65dfab0

Browse files
authored
FIX Fixes pandas extension arrays in check_array (#25813)
1 parent e75d8a6 commit 65dfab0

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

doc/whats_new/v1.3.rst

+3
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ Changelog
407407
:pr:`25733` by :user:`Brigitta Sipőcz <bsipocz>` and
408408
:user:`Jérémie du Boisberranger <jeremiedbb>`.
409409

410+
- |FIX| Fixes :func:`utils.validation.check_array` to properly convert pandas
411+
extension arrays. :pr:`25813` by `Thomas Fan`_.
412+
410413
:mod:`sklearn.semi_supervised`
411414
..............................
412415

sklearn/preprocessing/tests/test_label.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,19 @@ def test_label_binarizer_set_label_encoding():
118118

119119

120120
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
121-
def test_label_binarizer_pandas_nullable(dtype):
121+
@pytest.mark.parametrize("unique_first", [True, False])
122+
def test_label_binarizer_pandas_nullable(dtype, unique_first):
122123
"""Checks that LabelBinarizer works with pandas nullable dtypes.
123124
124125
Non-regression test for gh-25637.
125126
"""
126127
pd = pytest.importorskip("pandas")
127-
from sklearn.preprocessing import LabelBinarizer
128128

129129
y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
130+
if unique_first:
131+
# Calling unique creates a pandas array which has a different interface
132+
# compared to a pandas Series. Specifically, pandas arrays do not have "iloc".
133+
y_true = y_true.unique()
130134
lb = LabelBinarizer().fit(y_true)
131135
y_out = lb.transform([1, 0])
132136

sklearn/utils/tests/test_validation.py

+19
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,25 @@ def test_boolean_series_remains_boolean():
17621762
assert_array_equal(res, expected)
17631763

17641764

1765+
@pytest.mark.parametrize("input_values", [[0, 1, 0, 1, 0, np.nan], [0, 1, 0, 1, 0, 1]])
1766+
def test_pandas_array_returns_ndarray(input_values):
1767+
"""Check pandas array with extensions dtypes returns a numeric ndarray.
1768+
1769+
Non-regression test for gh-25637.
1770+
"""
1771+
pd = importorskip("pandas")
1772+
input_series = pd.array(input_values, dtype="Int32")
1773+
result = check_array(
1774+
input_series,
1775+
dtype=None,
1776+
ensure_2d=False,
1777+
allow_nd=False,
1778+
force_all_finite=False,
1779+
)
1780+
assert np.issubdtype(result.dtype.kind, np.floating)
1781+
assert_allclose(result, input_values)
1782+
1783+
17651784
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
17661785
def test_check_array_array_api_has_non_finite(array_namespace):
17671786
"""Checks that Array API arrays checks non-finite correctly."""

sklearn/utils/validation.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,15 @@ def _pandas_dtype_needs_early_conversion(pd_dtype):
626626
return False
627627

628628

629+
def _is_extension_array_dtype(array):
630+
try:
631+
from pandas.api.types import is_extension_array_dtype
632+
633+
return is_extension_array_dtype(array)
634+
except ImportError:
635+
return False
636+
637+
629638
def check_array(
630639
array,
631640
accept_sparse=False,
@@ -777,7 +786,9 @@ def check_array(
777786
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
778787
dtype_orig = np.result_type(*dtypes_orig)
779788

780-
elif hasattr(array, "iloc") and hasattr(array, "dtype"):
789+
elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
790+
array, "dtype"
791+
):
781792
# array is a pandas series
782793
pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
783794
if isinstance(array.dtype, np.dtype):

0 commit comments

Comments
 (0)