Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c013e26

Browse files
committed
ENH use the np.int dtype to encode integer classes
1 parent 238def1 commit c013e26

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

sklearn/preprocessing/label.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,14 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
463463
>>> mlb.fit_transform([(1, 2), (3,)])
464464
array([[1, 1, 0],
465465
[0, 0, 1]])
466-
>>> mlb.classes_ # doctest: +ELLIPSIS
467-
array([1, 2, 3]...)
468-
>>> mlb.fit_transform([set([1, 2]), set([3])])
469-
array([[1, 1, 0],
470-
[0, 0, 1]])
466+
>>> mlb.classes_
467+
array([1, 2, 3])
468+
469+
>>> mlb.fit_transform([set(['sci-fi', 'thriller']), set(['comedy'])])
470+
array([[0, 1, 1],
471+
[1, 0, 0]])
472+
>>> mlb.classes_
473+
array(['comedy', 'sci-fi', 'thriller'], dtype=object)
471474
"""
472475
def __init__(self, classes=None):
473476
self.classes = classes
@@ -490,7 +493,8 @@ def fit(self, y):
490493
classes = sorted(set(itertools.chain.from_iterable(y)))
491494
else:
492495
classes = self.classes
493-
self.classes_ = np.empty(len(classes), dtype=object)
496+
dtype = np.int if all(isinstance(c, int) for c in classes) else object
497+
self.classes_ = np.empty(len(classes), dtype=dtype)
494498
self.classes_[:] = classes
495499
return self
496500

@@ -520,8 +524,10 @@ def fit_transform(self, y):
520524

521525
# sort classes and reorder columns
522526
tmp = sorted(class_mapping, key=class_mapping.get)
527+
523528
# (make safe for tuples)
524-
class_mapping = np.empty(len(tmp), dtype=object)
529+
dtype = np.int if all(isinstance(c, int) for c in tmp) else object
530+
class_mapping = np.empty(len(tmp), dtype=dtype)
525531
class_mapping[:] = tmp
526532
self.classes_, inverse = np.unique(class_mapping, return_inverse=True)
527533
yt.indices = np.take(inverse, yt.indices)

0 commit comments

Comments
 (0)