Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9c4f023

Browse files
authored
FIX Improves nan support in LabelEncoder (#22629)
1 parent 53234c5 commit 9c4f023

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

doc/whats_new/v1.2.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,9 @@ Changelog
539539
is now deprecated and will be removed in version 1.4. Use `sparse_output` instead.
540540
:pr:`24412` by :user:`Rushil Desai <rusdes>`.
541541

542+
- |Fix| :class:`preprocessing.LabelEncoder` correctly encodes NaNs in `transform`.
543+
:pr:`22629` by `Thomas Fan`_.
544+
542545
:mod:`sklearn.svm`
543546
..................
544547

@@ -560,6 +563,9 @@ Changelog
560563
deterministic SVD used by the randomized SVD algorithm.
561564
:pr:`20617` by :user:`Srinath Kailasa <skailasa>`
562565

566+
- |Enhancement| :func:`utils.validation.column_or_1d` now accepts a `dtype`
567+
parameter to specific `y`'s dtype. :pr:`22629` by `Thomas Fan`_.
568+
563569
- |FIX| :func:`utils.multiclass.type_of_target` now properly handles sparse matrices.
564570
:pr:`14862` by :user:`Léonard Binet <leonardbinet>`.
565571

sklearn/preprocessing/_label.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def transform(self, y):
131131
Labels as normalized encodings.
132132
"""
133133
check_is_fitted(self)
134-
y = column_or_1d(y, warn=True)
134+
y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)
135135
# transform of empty array is empty array
136136
if _num_samples(y) == 0:
137137
return np.array([])

sklearn/preprocessing/tests/test_label.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,3 +643,15 @@ def test_inverse_binarize_multiclass():
643643
csr_matrix([[0, 1, 0], [-1, 0, -1], [0, 0, 0]]), np.arange(3)
644644
)
645645
assert_array_equal(got, np.array([1, 1, 0]))
646+
647+
648+
def test_nan_label_encoder():
649+
"""Check that label encoder encodes nans in transform.
650+
651+
Non-regression test for #22628.
652+
"""
653+
le = LabelEncoder()
654+
le.fit(["a", "a", "b", np.nan])
655+
656+
y_trans = le.transform([np.nan])
657+
assert_array_equal(y_trans, [2])

sklearn/utils/validation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,14 +1140,19 @@ def _check_y(y, multi_output=False, y_numeric=False, estimator=None):
11401140
return y
11411141

11421142

1143-
def column_or_1d(y, *, warn=False):
1143+
def column_or_1d(y, *, dtype=None, warn=False):
11441144
"""Ravel column or 1d numpy array, else raises an error.
11451145
11461146
Parameters
11471147
----------
11481148
y : array-like
11491149
Input data.
11501150
1151+
dtype : data-type, default=None
1152+
Data type for `y`.
1153+
1154+
.. versionadded:: 1.2
1155+
11511156
warn : bool, default=False
11521157
To control display of warnings.
11531158
@@ -1162,7 +1167,7 @@ def column_or_1d(y, *, warn=False):
11621167
If `y` is not a 1D array or a 2D array with a single row or column.
11631168
"""
11641169
xp, _ = get_namespace(y)
1165-
y = xp.asarray(y)
1170+
y = xp.asarray(y, dtype=dtype)
11661171
shape = y.shape
11671172
if len(shape) == 1:
11681173
return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp)

0 commit comments

Comments
 (0)