Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MRG] BUG Returns only public estimators in all_estimators #15380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def _tested_estimators():
for name, Estimator in all_estimators():
if issubclass(Estimator, BiclusterMixin):
continue
if name.startswith("_"):
continue
try:
estimator = _construct_instance(Estimator)
except SkipTest:
Expand Down
40 changes: 25 additions & 15 deletions sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import pkgutil
import inspect
from importlib import import_module
from operator import itemgetter
from collections.abc import Sequence
from contextlib import contextmanager
Expand All @@ -12,6 +13,7 @@
import platform
import struct
import timeit
from pathlib import Path

import warnings
import numpy as np
Expand Down Expand Up @@ -1155,7 +1157,6 @@ def all_estimators(include_meta_estimators=None,
and ``class`` is the actuall type of the class.
"""
# lazy import to avoid circular imports from sklearn.base
import sklearn
from ._testing import ignore_warnings
from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin,
TransformerMixin, ClusterMixin)
Expand Down Expand Up @@ -1183,20 +1184,29 @@ def is_abstract(c):
DeprecationWarning)

all_classes = []
# get parent folder
path = sklearn.__path__
for importer, modname, ispkg in pkgutil.walk_packages(
path=path, prefix='sklearn.', onerror=lambda x: None):
if ".tests." in modname or "externals" in modname:
continue
if IS_PYPY and ('_svmlight_format' in modname or
'feature_extraction._hashing' in modname):
continue
# Ignore deprecation warnings triggered at import time.
with ignore_warnings(category=FutureWarning):
module = __import__(modname, fromlist="dummy")
classes = inspect.getmembers(module, inspect.isclass)
all_classes.extend(classes)
modules_to_ignore = {"tests", "externals", "setup", "conftest"}
root = str(Path(__file__).parent.parent) # sklearn package
# Ignore deprecation warnings triggered at import time and from walking
# packages
with ignore_warnings(category=FutureWarning):
for importer, modname, ispkg in pkgutil.walk_packages(
path=[root], prefix='sklearn.'):
mod_parts = modname.split(".")
if (any(part in modules_to_ignore for part in mod_parts)
or '._' in modname):
continue
module = import_module(modname)
classes = inspect.getmembers(module, inspect.isclass)
classes = [(name, est_cls) for name, est_cls in classes
if not name.startswith("_")]

# TODO: Remove when FeatureHasher is implemented in PYPY
# Skips FeatureHasher for PYPY
if IS_PYPY and 'feature_extraction' in modname:
classes = [(name, est_cls) for name, est_cls in classes
if name == "FeatureHasher"]

all_classes.extend(classes)

all_classes = set(all_classes)

Expand Down
9 changes: 9 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_X_y, check_array
from sklearn.utils import all_estimators


class CorrectNotFittedError(ValueError):
Expand Down Expand Up @@ -572,6 +573,14 @@ def test_check_class_weight_balanced_linear_classifier():
BadBalancedWeightsClassifier)


def test_all_estimators_all_public():
# all_estimator should not fail when pytest is not installed and return
# only public estimators
estimators = all_estimators()
for est in estimators:
assert not est.__class__.__name__.startswith("_")


if __name__ == '__main__':
# This module is run as a script to check that we have no dependency on
# pytest for estimator checks.
Expand Down