diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index e7cd4c1a0e19d..ce579874f886c 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -244,6 +244,9 @@ Changelog during `transform` with no prior call to `fit` or `fit_transform`. :pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`. +- |Enhancement| :func:`utils.multiclass.type_of_target` can identify pandas + nullable data types as classification targets. :pr:`25638` by `Thomas Fan`_. + :mod:`sklearn.semi_supervised` .............................. diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 222f445555200..255c83c852325 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -1079,6 +1079,24 @@ def test_confusion_matrix_dtype(): assert cm[1, 1] == -2 +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_confusion_matrix_pandas_nullable(dtype): + """Checks that confusion_matrix works with pandas nullable dtypes. + + Non-regression test for gh-25635. 
+ """ + pd = pytest.importorskip("pandas") + + y_ndarray = np.array([1, 0, 0, 1, 0, 1, 1, 0, 1]) + y_true = pd.Series(y_ndarray, dtype=dtype) + y_predicted = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64") + + output = confusion_matrix(y_true, y_predicted) + expected_output = confusion_matrix(y_ndarray, y_predicted) + + assert_array_equal(output, expected_output) + + def test_classification_report_multiclass(): # Test performance report iris = datasets.load_iris() diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index bd5ba6d2da9ee..b90218a97016b 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -117,6 +117,22 @@ def test_label_binarizer_set_label_encoding(): assert_array_equal(lb.inverse_transform(got), inp) +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_label_binarizer_pandas_nullable(dtype): + """Checks that LabelBinarizer works with pandas nullable dtypes. + + Non-regression test for gh-25637. 
+ """ + pd = pytest.importorskip("pandas") + from sklearn.preprocessing import LabelBinarizer + + y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype) + lb = LabelBinarizer().fit(y_true) + y_out = lb.transform([1, 0]) + + assert_array_equal(y_out, [[1], [0]]) + + @ignore_warnings def test_label_binarizer_errors(): # Check that invalid arguments yield ValueError diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 5eaef2fde87e4..f14b981f9b83a 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -155,14 +155,25 @@ def is_multilabel(y): if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api: # DeprecationWarning will be replaced by ValueError, see NEP 34 # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html + check_y_kwargs = dict( + accept_sparse=True, + allow_nd=True, + force_all_finite=False, + ensure_2d=False, + ensure_min_samples=0, + ensure_min_features=0, + ) with warnings.catch_warnings(): warnings.simplefilter("error", np.VisibleDeprecationWarning) try: - y = xp.asarray(y) - except (np.VisibleDeprecationWarning, ValueError): + y = check_array(y, dtype=None, **check_y_kwargs) + except (np.VisibleDeprecationWarning, ValueError) as e: + if str(e).startswith("Complex data not supported"): + raise + # dtype=object should be provided explicitly for ragged arrays, # see NEP 34 - y = xp.asarray(y, dtype=object) + y = check_array(y, dtype=object, **check_y_kwargs) if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1): return False @@ -302,15 +313,27 @@ def type_of_target(y, input_name=""): # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html # We therefore catch both deprecation (NumPy < 1.24) warning and # value error (NumPy >= 1.24). 
+ check_y_kwargs = dict( + accept_sparse=True, + allow_nd=True, + force_all_finite=False, + ensure_2d=False, + ensure_min_samples=0, + ensure_min_features=0, + ) + with warnings.catch_warnings(): warnings.simplefilter("error", np.VisibleDeprecationWarning) if not issparse(y): try: - y = xp.asarray(y) - except (np.VisibleDeprecationWarning, ValueError): + y = check_array(y, dtype=None, **check_y_kwargs) + except (np.VisibleDeprecationWarning, ValueError) as e: + if str(e).startswith("Complex data not supported"): + raise + # dtype=object should be provided explicitly for ragged arrays, # see NEP 34 - y = xp.asarray(y, dtype=object) + y = check_array(y, dtype=object, **check_y_kwargs) # The old sequence of sequences format try: diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py index cf5858d0f52f9..731edbefc3925 100644 --- a/sklearn/utils/tests/test_multiclass.py +++ b/sklearn/utils/tests/test_multiclass.py @@ -346,6 +346,42 @@ def test_type_of_target_pandas_sparse(): type_of_target(y) +def test_type_of_target_pandas_nullable(): + """Check that type_of_target works with pandas nullable dtypes.""" + pd = pytest.importorskip("pandas") + + for dtype in ["Int32", "Float32"]: + y_true = pd.Series([1, 0, 2, 3, 4], dtype=dtype) + assert type_of_target(y_true) == "multiclass" + + y_true = pd.Series([1, 0, 1, 0], dtype=dtype) + assert type_of_target(y_true) == "binary" + + y_true = pd.DataFrame([[1.4, 3.1], [3.1, 1.4]], dtype="Float32") + assert type_of_target(y_true) == "continuous-multioutput" + + y_true = pd.DataFrame([[0, 1], [1, 1]], dtype="Int32") + assert type_of_target(y_true) == "multilabel-indicator" + + y_true = pd.DataFrame([[1, 2], [3, 1]], dtype="Int32") + assert type_of_target(y_true) == "multiclass-multioutput" + + +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_unique_labels_pandas_nullable(dtype): + """Checks that unique_labels works with pandas nullable dtypes. 
+ + Non-regression test for gh-25634. + """ + pd = pytest.importorskip("pandas") + + y_true = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype) + y_predicted = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64") + + labels = unique_labels(y_true, y_predicted) + assert_array_equal(labels, [0, 1]) + + def test_class_distribution(): y = np.array( [