@@ -463,11 +463,14 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
463
463
>>> mlb.fit_transform([(1, 2), (3,)])
464
464
array([[1, 1, 0],
465
465
[0, 0, 1]])
466
- >>> mlb.classes_ # doctest: +ELLIPSIS
467
- array([1, 2, 3]...)
468
- >>> mlb.fit_transform([set([1, 2]), set([3])])
469
- array([[1, 1, 0],
470
- [0, 0, 1]])
466
+ >>> mlb.classes_
467
+ array([1, 2, 3])
468
+
469
+ >>> mlb.fit_transform([set(['sci-fi', 'thriller']), set(['comedy'])])
470
+ array([[0, 1, 1],
471
+ [1, 0, 0]])
472
+ >>> mlb.classes_
473
+ array(['comedy', 'sci-fi', 'thriller'], dtype=object)
471
474
"""
472
475
def __init__ (self , classes = None ):
473
476
self .classes = classes
@@ -490,7 +493,8 @@ def fit(self, y):
490
493
classes = sorted (set (itertools .chain .from_iterable (y )))
491
494
else :
492
495
classes = self .classes
493
- self .classes_ = np .empty (len (classes ), dtype = object )
496
+ dtype = np .int if all (isinstance (c , int ) for c in classes ) else object
497
+ self .classes_ = np .empty (len (classes ), dtype = dtype )
494
498
self .classes_ [:] = classes
495
499
return self
496
500
@@ -520,8 +524,10 @@ def fit_transform(self, y):
520
524
521
525
# sort classes and reorder columns
522
526
tmp = sorted (class_mapping , key = class_mapping .get )
527
+
523
528
# (make safe for tuples)
524
- class_mapping = np .empty (len (tmp ), dtype = object )
529
+ dtype = np .int if all (isinstance (c , int ) for c in tmp ) else object
530
+ class_mapping = np .empty (len (tmp ), dtype = dtype )
525
531
class_mapping [:] = tmp
526
532
self .classes_ , inverse = np .unique (class_mapping , return_inverse = True )
527
533
yt .indices = np .take (inverse , yt .indices )
0 commit comments