@@ -74,7 +74,7 @@ def assert_not_in(x, container):
74
74
except ImportError :
75
75
# for Py 2.6
76
76
def assert_raises_regex (expected_exception , expected_regexp ,
77
- callable_obj = None , * args , ** kwargs ):
77
+ callable_obj = None , * args , ** kwargs ):
78
78
"""Helper function to check for message patterns in exceptions"""
79
79
80
80
not_raised = False
@@ -157,7 +157,7 @@ def assert_warns(warning_class, func, *args, **kw):
157
157
if hasattr (np , 'VisibleDeprecationWarning' ):
158
158
# Filter out numpy-specific warnings in numpy >= 1.9
159
159
w = [e for e in w
160
- if not e .category is np .VisibleDeprecationWarning ]
160
+ if e .category is not np .VisibleDeprecationWarning ]
161
161
162
162
# Verify some things
163
163
if not len (w ) > 0 :
@@ -227,7 +227,7 @@ def assert_warns_message(warning_class, message, func, *args, **kw):
227
227
if not check_in_message (msg ):
228
228
raise AssertionError ("The message received ('%s') for <%s> is "
229
229
"not the one you expected ('%s')"
230
- % (msg , func .__name__ , message
230
+ % (msg , func .__name__ , message
231
231
))
232
232
return result
233
233
@@ -246,7 +246,7 @@ def assert_no_warnings(func, *args, **kw):
246
246
if hasattr (np , 'VisibleDeprecationWarning' ):
247
247
# Filter out numpy-specific warnings in numpy >= 1.9
248
248
w = [e for e in w
249
- if not e .category is np .VisibleDeprecationWarning ]
249
+ if e .category is not np .VisibleDeprecationWarning ]
250
250
251
251
if len (w ) > 0 :
252
252
raise AssertionError ("Got warnings when calling %s: %s"
@@ -510,11 +510,12 @@ def all_estimators(include_meta_estimators=False, include_other=False,
510
510
include_dont_test : boolean, default=False
511
511
Whether to include "special" label estimator or test processors.
512
512
513
- type_filter : string or None, default=None
513
+ type_filter : string, list of string, or None, default=None
514
514
Which kind of estimators should be returned. If None, no filter is
515
515
applied and all estimators are returned. Possible values are
516
516
'classifier', 'regressor', 'cluster' and 'transformer' to get
517
- estimators only of these specific types.
517
+ estimators only of these specific types, or a list of these to
518
+ get the estimators that fit at least one of the types.
518
519
519
520
Returns
520
521
-------
@@ -556,26 +557,29 @@ def is_abstract(c):
556
557
# possibly get rid of meta estimators
557
558
if not include_meta_estimators :
558
559
estimators = [c for c in estimators if not c [0 ] in META_ESTIMATORS ]
559
-
560
- if type_filter == 'classifier' :
561
- estimators = [est for est in estimators
562
- if issubclass (est [1 ], ClassifierMixin )]
563
- elif type_filter == 'regressor' :
564
- estimators = [est for est in estimators
565
- if issubclass (est [1 ], RegressorMixin )]
566
- elif type_filter == 'transformer' :
567
- estimators = [est for est in estimators
568
- if issubclass (est [1 ], TransformerMixin )]
569
- elif type_filter == 'cluster' :
570
- estimators = [est for est in estimators
571
- if issubclass (est [1 ], ClusterMixin )]
572
- elif type_filter is not None :
573
- raise ValueError ("Parameter type_filter must be 'classifier', "
574
- "'regressor', 'transformer', 'cluster' or None, got"
575
- " %s." % repr (type_filter ))
576
-
577
- # We sort in order to have reproducible test failures
578
- return sorted (estimators )
560
+ if type_filter is not None :
561
+ if not isinstance (type_filter , list ):
562
+ type_filter = [type_filter ]
563
+ else :
564
+ type_filter = list (type_filter ) # copy
565
+ filtered_estimators = []
566
+ filters = {'classifier' : ClassifierMixin ,
567
+ 'regressor' : RegressorMixin ,
568
+ 'transformer' : TransformerMixin ,
569
+ 'cluster' : ClusterMixin }
570
+ for name , mixin in filters .items ():
571
+ if name in type_filter :
572
+ type_filter .remove (name )
573
+ filtered_estimators .extend ([est for est in estimators
574
+ if issubclass (est [1 ], mixin )])
575
+ estimators = filtered_estimators
576
+ if type_filter :
577
+ raise ValueError ("Parameter type_filter must be 'classifier', "
578
+ "'regressor', 'transformer', 'cluster' or None, got"
579
+ " %s." % repr (type_filter ))
580
+
581
+ # drop duplicates, sort for reproducibility
582
+ return sorted (set (estimators ))
579
583
580
584
581
585
def set_random_state (estimator , random_state = 0 ):
0 commit comments