diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 0dbf18e46bb0c..13d86ae41f375 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -22,8 +22,6 @@ from sklearn.utils.testing import ignore_warnings import sklearn -from sklearn.base import (ClassifierMixin, RegressorMixin, - TransformerMixin, ClusterMixin) from sklearn.preprocessing import StandardScaler from sklearn.datasets import make_classification @@ -86,9 +84,7 @@ def test_all_estimators(): def test_estimators_sparse_data(): # All estimators should either deal with sparse data or raise an # exception with type TypeError and an intelligible error message - estimators = all_estimators() - estimators = [(name, Estimator) for name, Estimator in estimators - if issubclass(Estimator, (ClassifierMixin, RegressorMixin))] + estimators = all_estimators(type_filter=['classifier', 'regressor']) for name, Estimator in estimators: yield check_regressors_classifiers_sparse_data, name, Estimator @@ -113,12 +109,8 @@ def test_transformers(): def test_estimators_nan_inf(): # Test that all estimators check their input for NaN's and infs - estimators = all_estimators() - estimators = [(name, E) for name, E in estimators - if (issubclass(E, ClassifierMixin) or - issubclass(E, RegressorMixin) or - issubclass(E, TransformerMixin) or - issubclass(E, ClusterMixin))] + estimators = all_estimators(type_filter=['classifier', 'regressor', + 'transformer', 'cluster']) for name, Estimator in estimators: if name not in CROSS_DECOMPOSITION + ['Imputer']: yield check_estimators_nan_inf, name, Estimator diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index ee0cc4b65666b..a25dfd47a30d0 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -74,7 +74,7 @@ def assert_not_in(x, container): except ImportError: # for Py 2.6 def assert_raises_regex(expected_exception, expected_regexp, - callable_obj=None, *args, **kwargs): + callable_obj=None, *args, **kwargs): """Helper function to check for message patterns in exceptions""" not_raised = False @@ -157,7 +157,7 @@ def assert_warns(warning_class, func, *args, **kw): if hasattr(np, 'VisibleDeprecationWarning'): # Filter out numpy-specific warnings in numpy >= 1.9 w = [e for e in w - if not e.category is np.VisibleDeprecationWarning] + if e.category is not np.VisibleDeprecationWarning] # Verify some things if not len(w) > 0: @@ -227,7 +227,7 @@ def assert_warns_message(warning_class, message, func, *args, **kw): if not check_in_message(msg): raise AssertionError("The message received ('%s') for <%s> is " "not the one you expected ('%s')" - % (msg, func.__name__, message + % (msg, func.__name__, message )) return result @@ -246,7 +246,7 @@ def assert_no_warnings(func, *args, **kw): if hasattr(np, 'VisibleDeprecationWarning'): # Filter out numpy-specific warnings in numpy >= 1.9 w = [e for e in w - if not e.category is np.VisibleDeprecationWarning] + if e.category is not np.VisibleDeprecationWarning] if len(w) > 0: raise AssertionError("Got warnings when calling %s: %s" @@ -510,11 +510,12 @@ def all_estimators(include_meta_estimators=False, include_other=False, include_dont_test : boolean, default=False Whether to include "special" label estimator or test processors. - type_filter : string or None, default=None + type_filter : string, list of string, or None, default=None Which kind of estimators should be returned. If None, no filter is applied and all estimators are returned. Possible values are 'classifier', 'regressor', 'cluster' and 'transformer' to get - estimators only of these specific types. + estimators only of these specific types, or a list of these to + get the estimators that fit at least one of the types. Returns ------- @@ -556,26 +557,29 @@ def is_abstract(c): # possibly get rid of meta estimators if not include_meta_estimators: estimators = [c for c in estimators if not c[0] in META_ESTIMATORS] - - if type_filter == 'classifier': - estimators = [est for est in estimators - if issubclass(est[1], ClassifierMixin)] - elif type_filter == 'regressor': - estimators = [est for est in estimators - if issubclass(est[1], RegressorMixin)] - elif type_filter == 'transformer': - estimators = [est for est in estimators - if issubclass(est[1], TransformerMixin)] - elif type_filter == 'cluster': - estimators = [est for est in estimators - if issubclass(est[1], ClusterMixin)] - elif type_filter is not None: - raise ValueError("Parameter type_filter must be 'classifier', " - "'regressor', 'transformer', 'cluster' or None, got" - " %s." % repr(type_filter)) - - # We sort in order to have reproducible test failures - return sorted(estimators) + if type_filter is not None: + if not isinstance(type_filter, list): + type_filter = [type_filter] + else: + type_filter = list(type_filter) # copy + filtered_estimators = [] + filters = {'classifier': ClassifierMixin, + 'regressor': RegressorMixin, + 'transformer': TransformerMixin, + 'cluster': ClusterMixin} + for name, mixin in filters.items(): + if name in type_filter: + type_filter.remove(name) + filtered_estimators.extend([est for est in estimators + if issubclass(est[1], mixin)]) + estimators = filtered_estimators + if type_filter: + raise ValueError("Parameter type_filter must be 'classifier', " + "'regressor', 'transformer', 'cluster' or None, got" + " %s." % repr(type_filter)) + + # drop duplicates, sort for reproducibility + return sorted(set(estimators)) def set_random_state(estimator, random_state=0):