Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 165f949

Browse files
committed
Merge pull request #3934 from amueller/common_test_slight_cleanup
MRG Allow list of strings for type_filter in all_estimators.
2 parents ea1d134 + f35116a commit 165f949

File tree

2 files changed

+33
-37
lines changed

2 files changed

+33
-37
lines changed

sklearn/tests/test_common.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
from sklearn.utils.testing import ignore_warnings
2323

2424
import sklearn
25-
from sklearn.base import (ClassifierMixin, RegressorMixin,
26-
TransformerMixin, ClusterMixin)
2725
from sklearn.preprocessing import StandardScaler
2826
from sklearn.datasets import make_classification
2927

@@ -86,9 +84,7 @@ def test_all_estimators():
8684
def test_estimators_sparse_data():
8785
# All estimators should either deal with sparse data or raise an
8886
# exception with type TypeError and an intelligible error message
89-
estimators = all_estimators()
90-
estimators = [(name, Estimator) for name, Estimator in estimators
91-
if issubclass(Estimator, (ClassifierMixin, RegressorMixin))]
87+
estimators = all_estimators(type_filter=['classifier', 'regressor'])
9288
for name, Estimator in estimators:
9389
yield check_regressors_classifiers_sparse_data, name, Estimator
9490

@@ -113,12 +109,8 @@ def test_transformers():
113109

114110
def test_estimators_nan_inf():
115111
# Test that all estimators check their input for NaN's and infs
116-
estimators = all_estimators()
117-
estimators = [(name, E) for name, E in estimators
118-
if (issubclass(E, ClassifierMixin) or
119-
issubclass(E, RegressorMixin) or
120-
issubclass(E, TransformerMixin) or
121-
issubclass(E, ClusterMixin))]
112+
estimators = all_estimators(type_filter=['classifier', 'regressor',
113+
'transformer', 'cluster'])
122114
for name, Estimator in estimators:
123115
if name not in CROSS_DECOMPOSITION + ['Imputer']:
124116
yield check_estimators_nan_inf, name, Estimator

sklearn/utils/testing.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def assert_not_in(x, container):
7474
except ImportError:
7575
# for Py 2.6
7676
def assert_raises_regex(expected_exception, expected_regexp,
77-
callable_obj=None, *args, **kwargs):
77+
callable_obj=None, *args, **kwargs):
7878
"""Helper function to check for message patterns in exceptions"""
7979

8080
not_raised = False
@@ -157,7 +157,7 @@ def assert_warns(warning_class, func, *args, **kw):
157157
if hasattr(np, 'VisibleDeprecationWarning'):
158158
# Filter out numpy-specific warnings in numpy >= 1.9
159159
w = [e for e in w
160-
if not e.category is np.VisibleDeprecationWarning]
160+
if e.category is not np.VisibleDeprecationWarning]
161161

162162
# Verify some things
163163
if not len(w) > 0:
@@ -227,7 +227,7 @@ def assert_warns_message(warning_class, message, func, *args, **kw):
227227
if not check_in_message(msg):
228228
raise AssertionError("The message received ('%s') for <%s> is "
229229
"not the one you expected ('%s')"
230-
% (msg, func.__name__, message
230+
% (msg, func.__name__, message
231231
))
232232
return result
233233

@@ -246,7 +246,7 @@ def assert_no_warnings(func, *args, **kw):
246246
if hasattr(np, 'VisibleDeprecationWarning'):
247247
# Filter out numpy-specific warnings in numpy >= 1.9
248248
w = [e for e in w
249-
if not e.category is np.VisibleDeprecationWarning]
249+
if e.category is not np.VisibleDeprecationWarning]
250250

251251
if len(w) > 0:
252252
raise AssertionError("Got warnings when calling %s: %s"
@@ -510,11 +510,12 @@ def all_estimators(include_meta_estimators=False, include_other=False,
510510
include_dont_test : boolean, default=False
511511
Whether to include "special" label estimator or test processors.
512512
513-
type_filter : string or None, default=None
513+
type_filter : string, list of strings, or None, default=None
514514
Which kind of estimators should be returned. If None, no filter is
515515
applied and all estimators are returned. Possible values are
516516
'classifier', 'regressor', 'cluster' and 'transformer' to get
517-
estimators only of these specific types.
517+
estimators only of these specific types, or a list of these to
518+
get the estimators that fit at least one of the types.
518519
519520
Returns
520521
-------
@@ -556,26 +557,29 @@ def is_abstract(c):
556557
# possibly get rid of meta estimators
557558
if not include_meta_estimators:
558559
estimators = [c for c in estimators if not c[0] in META_ESTIMATORS]
559-
560-
if type_filter == 'classifier':
561-
estimators = [est for est in estimators
562-
if issubclass(est[1], ClassifierMixin)]
563-
elif type_filter == 'regressor':
564-
estimators = [est for est in estimators
565-
if issubclass(est[1], RegressorMixin)]
566-
elif type_filter == 'transformer':
567-
estimators = [est for est in estimators
568-
if issubclass(est[1], TransformerMixin)]
569-
elif type_filter == 'cluster':
570-
estimators = [est for est in estimators
571-
if issubclass(est[1], ClusterMixin)]
572-
elif type_filter is not None:
573-
raise ValueError("Parameter type_filter must be 'classifier', "
574-
"'regressor', 'transformer', 'cluster' or None, got"
575-
" %s." % repr(type_filter))
576-
577-
# We sort in order to have reproducible test failures
578-
return sorted(estimators)
560+
if type_filter is not None:
561+
if not isinstance(type_filter, list):
562+
type_filter = [type_filter]
563+
else:
564+
type_filter = list(type_filter) # copy
565+
filtered_estimators = []
566+
filters = {'classifier': ClassifierMixin,
567+
'regressor': RegressorMixin,
568+
'transformer': TransformerMixin,
569+
'cluster': ClusterMixin}
570+
for name, mixin in filters.items():
571+
if name in type_filter:
572+
type_filter.remove(name)
573+
filtered_estimators.extend([est for est in estimators
574+
if issubclass(est[1], mixin)])
575+
estimators = filtered_estimators
576+
if type_filter:
577+
raise ValueError("Parameter type_filter must be 'classifier', "
578+
"'regressor', 'transformer', 'cluster' or None, got"
579+
" %s." % repr(type_filter))
580+
581+
# drop duplicates, sort for reproducibility
582+
return sorted(set(estimators))
579583

580584

581585
def set_random_state(estimator, random_state=0):

0 commit comments

Comments
 (0)