diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 221dd52834c90..9fc3075c5fe28 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -78,8 +78,6 @@ def _tested_estimators(): for name, Estimator in all_estimators(): if issubclass(Estimator, BiclusterMixin): continue - if name.startswith("_"): - continue try: estimator = _construct_instance(Estimator) except SkipTest: diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 4d4ef606341ca..c20ea5ab11d31 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -3,6 +3,7 @@ """ import pkgutil import inspect +from importlib import import_module from operator import itemgetter from collections.abc import Sequence from contextlib import contextmanager @@ -12,6 +13,7 @@ import platform import struct import timeit +from pathlib import Path import warnings import numpy as np @@ -1155,7 +1157,6 @@ def all_estimators(include_meta_estimators=None, and ``class`` is the actuall type of the class. """ # lazy import to avoid circular imports from sklearn.base - import sklearn from ._testing import ignore_warnings from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin, ClusterMixin) @@ -1183,20 +1184,29 @@ def is_abstract(c): DeprecationWarning) all_classes = [] - # get parent folder - path = sklearn.__path__ - for importer, modname, ispkg in pkgutil.walk_packages( - path=path, prefix='sklearn.', onerror=lambda x: None): - if ".tests." in modname or "externals" in modname: - continue - if IS_PYPY and ('_svmlight_format' in modname or - 'feature_extraction._hashing' in modname): - continue - # Ignore deprecation warnings triggered at import time. - with ignore_warnings(category=FutureWarning): - module = __import__(modname, fromlist="dummy") - classes = inspect.getmembers(module, inspect.isclass) - all_classes.extend(classes) + modules_to_ignore = {"tests", "externals", "setup", "conftest"} + root = str(Path(__file__).parent.parent) # sklearn package + # Ignore deprecation warnings triggered at import time and from walking + # packages + with ignore_warnings(category=FutureWarning): + for importer, modname, ispkg in pkgutil.walk_packages( + path=[root], prefix='sklearn.'): + mod_parts = modname.split(".") + if (any(part in modules_to_ignore for part in mod_parts) + or '._' in modname): + continue + module = import_module(modname) + classes = inspect.getmembers(module, inspect.isclass) + classes = [(name, est_cls) for name, est_cls in classes + if not name.startswith("_")] + + # TODO: Remove when FeatureHasher is implemented in PYPY + # Skips FeatureHasher for PYPY + if IS_PYPY and 'feature_extraction' in modname: + classes = [(name, est_cls) for name, est_cls in classes + if name == "FeatureHasher"] + + all_classes.extend(classes) all_classes = set(all_classes) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index f0c014829483f..15b423d6e0ce8 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -34,6 +34,7 @@ from sklearn.neighbors import KNeighborsRegressor from sklearn.tree import DecisionTreeClassifier from sklearn.utils.validation import check_X_y, check_array +from sklearn.utils import all_estimators class CorrectNotFittedError(ValueError): @@ -572,6 +573,14 @@ def test_check_class_weight_balanced_linear_classifier(): BadBalancedWeightsClassifier) +def test_all_estimators_all_public(): + # all_estimator should not fail when pytest is not installed and return + # only public estimators + estimators = all_estimators() + for est in estimators: + assert not est.__class__.__name__.startswith("_") + + if __name__ == '__main__': # This module is run as a script to check that we have no dependency on # pytest for estimator checks.