[MRG] Estimator tags #8022
@@ -1419,22 +1419,18 @@ advised to maintain notes on the `GitHub wiki

Specific models
---------------
Classifiers should accept ``y`` (target) arguments to ``fit`` that are
sequences (lists, arrays) of either strings or integers. They should not
assume that the class labels are a contiguous range of integers; instead, they
should store a list of classes in a ``classes_`` attribute or property. The
order of class labels in this attribute should match the order in which
``predict_proba``, ``predict_log_proba`` and ``decision_function`` return their
values. The easiest way to achieve this is to put::
    self.classes_, y = np.unique(y, return_inverse=True)

in ``fit``. This returns a new ``y`` that contains class indexes, rather than
labels, in the range [0, ``n_classes``).
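As a quick illustration of the ``np.unique`` idiom above (the label values
here are made up for the example):

```python
import numpy as np

# String labels as they might be passed to ``fit``
y = np.array(["spam", "ham", "spam", "eggs"])

# ``classes_`` holds the sorted unique labels; ``y_encoded`` holds integer
# indexes into ``classes_`` in the range [0, n_classes)
classes_, y_encoded = np.unique(y, return_inverse=True)

print(classes_)   # ['eggs' 'ham' 'spam']
print(y_encoded)  # [2 1 2 0]
```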
A classifier's ``predict`` method should return
arrays containing class labels from ``classes_``.
@@ -1445,14 +1441,89 @@ this can be achieved with::

    D = self.decision_function(X)
    return self.classes_[np.argmax(D, axis=1)]
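A minimal sketch of that mapping from decision values back to labels (the
decision values here are invented for the example):

```python
import numpy as np

classes_ = np.array(["eggs", "ham", "spam"])

# One row of decision values per sample, one column per class
D = np.array([[0.1, 2.0, -1.0],
              [3.0, 0.0, 0.5]])

# argmax picks the column index of the highest score per row; indexing
# with ``classes_`` converts those indexes back to the original labels
pred = classes_[np.argmax(D, axis=1)]
print(pred)  # ['ham' 'eggs']
```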
In linear models, coefficients are stored in an array called ``coef_``, and the
independent term is stored in ``intercept_``. ``sklearn.linear_model.base``
contains a few base classes and mixins that implement common linear model
patterns.

The :mod:`sklearn.utils.multiclass` module contains useful functions
for working with multiclass and multilabel problems.
Estimator Tags
--------------

.. warning::

    The estimator tags are experimental and the API is subject to change.
Scikit-learn introduced estimator tags in version 0.21. These are annotations
of estimators that allow programmatic inspection of their capabilities, such as
sparse matrix support, supported output types and supported methods. The
estimator tags are a dictionary returned by the method ``_get_tags()``. These
tags are used by the common tests and the
:func:`sklearn.utils.estimator_checks.check_estimator` function to decide what
tests to run and what input data is appropriate. Tags can depend on estimator
parameters or even system architecture and can in general only be determined at
runtime.

Review comment: You might want to note that these tags may be dependent on
estimator parameters and even system architecture, and hence are a method on
an instance, rather than a property of the class. You should probably also
define the default implementation and
The default value of all tags except for ``X_types`` is ``False``.

The current set of estimator tags are:
non_deterministic
    whether the estimator is not deterministic given a fixed ``random_state``

requires_positive_data - unused for now
    whether the estimator requires positive X.

no_validation
    whether the estimator skips input-validation. This is only meant for
    stateless and dummy transformers!

multioutput - unused for now
    whether a regressor supports multi-target outputs or a classifier supports
    multi-class multi-output.
multilabel
    whether the estimator supports multilabel output

stateless
    whether the estimator needs access to data for fitting. Even though an
    estimator is stateless, it might still need a call to ``fit`` for
    initialization.

Review thread:

- (should we deprecate the need for a call to fit for initialisation in
  stateless estimators?)
- I don't think we can, because "stateless" can still mean it depends on
- I'm not sure "stateless" is the right word then? We mean "data independent"?
- do we want to have separate tags for data independent and no state at all?
- Dunno. What's the use case? If so, we could consider a ternary tag...
- I guess the main use-case here was that some estimators didn't complain if
  the number of features was different in fit and transform, possibly only
  AdditiveChi2Sampler.
- Hm AdditiveChi2Sampler requires calling
- Opened #12616 to follow up. I don't think there's a good reason for a
  ternary tag. Right now this is used for testing two things: checking that
  calling
allow_nan
    whether the estimator supports data with missing values encoded as np.NaN

Review thread:

- There may be some subtlety to this. What if it supports NaN at transform but
  not at fit (with some parameters)?
- yes, or the other way around?
- I don't think the other way around would be a case of interest. In #11635 we
  identified that you might have a feature selector that could not train on
  missing data (if only because the parameters weren't right) but there's no
  reason it shouldn't transform with missing data.
poor_score
    whether the estimator fails to provide a "reasonable" test-set score,
    which currently for regression is an R2 of 0.5 on a subset of the Boston
    housing dataset, and for classification an accuracy of 0.83 on
    ``make_blobs(n_samples=300, random_state=0)``. These datasets and values
    are based on current estimators in sklearn and might be replaced by
    something more systematic.

multioutput_only
    whether the estimator supports only multi-output classification or
    regression.
_skip_test
    whether to skip common tests entirely. Don't use this unless you have a
    *very good* reason.

Review thread:

- if the _ doesn't mean private, perhaps we can use something like !
- it kinda means private in the sense that no-one should ever use it ;)
Review thread:

- X_types is undocumented at present, and is mysterious... should it not be a
  series of boolean tags instead of a list?
- That would require us to define a list of possible input types now and it
  would be harder to change in the future though, right?
- Why is a set of boolean tags harder than a list?
- I felt it might be more natural to add new things to a set/list than add
  another boolean variable to a set/list of boolean variables.
- Dunno. A list is fine... and could have benefits if the objects in the list
  are not merely strings

X_types
    Supported input types for X as a list of strings. Tests are currently only
    run if '2darray' is contained in the list, signifying that the estimator
    takes continuous 2d numpy arrays as input. The default value is
    ['2darray']. Other possible types are ``'string'``, ``'sparse'``,
    ``'categorical'``, ``dict``, ``'1dlabels'`` and ``'2dlabels'``. The goal
    is that in the future the supported input type will determine the data
    used during testing, in particular for ``'string'``, ``'sparse'`` and
    ``'categorical'`` data. For now, the tests for sparse data do not make use
    of the ``'sparse'`` tag.
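To make the override mechanics concrete, here is a small self-contained
sketch, not the scikit-learn implementation itself; ``ToyEstimator`` and the
trimmed-down ``_DEFAULT_TAGS`` are invented for illustration:

```python
# A reduced default-tags dictionary (the real one has more entries)
_DEFAULT_TAGS = {'allow_nan': False,
                 'multioutput': False,
                 'X_types': ['2darray']}

class ToyEstimator:
    # Hypothetical estimator that accepts NaN values in X
    def _more_tags(self):
        return {'allow_nan': True}

    def _get_tags(self):
        # Simplified: start from the defaults, apply this class's overrides
        tags = _DEFAULT_TAGS.copy()
        tags.update(self._more_tags())
        return tags

tags = ToyEstimator()._get_tags()
print(tags['allow_nan'])  # True
print(tags['X_types'])    # ['2darray']
```

A tool consuming the tags only ever does dictionary lookups like these, which
is what lets the common tests choose appropriate input data per estimator.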
In addition to the tags, estimators also need to declare any non-optional
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
which is a list or tuple. If ``_required_parameters`` is only
``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
instantiated with an instance of ``LinearDiscriminantAnalysis`` (or
``RidgeRegression`` if the estimator is a regressor) in the tests. The choice
of these two models is somewhat idiosyncratic but both should provide robust
closed-form solutions.

Review thread:

- Can we not determine this automatically by inspecting
- @jnothman asked the same, so maybe my intentions are indeed unclear.
- Would the following be an appropriate substitute, shooting two birds with
  one stone::

      class BaseEstimator:
          ...

          @classmethod
          def _get_instances_for_checking(cls):
              yield cls()

- This also has the potential to make most of
- +1 for possibly separate PR. I kinda don't want to mess with the default
  construction test too much...
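As an illustration of the ``_required_parameters`` declaration, a
meta-estimator whose first ``__init__`` argument has no default might look
like this (``ToyBagging`` is a made-up class, not part of scikit-learn):

```python
class ToyBagging:
    # Tells check_estimator-style tooling that ``estimator`` must be
    # supplied when instantiating this class in the common tests
    _required_parameters = ["estimator"]

    def __init__(self, estimator, n_copies=3):
        self.estimator = estimator
        self.n_copies = n_copies

bagger = ToyBagging(estimator="some_base_estimator")
print(bagger._required_parameters)  # ['estimator']
```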
.. _reading-code:

Reading the existing code base
@@ -6,12 +6,25 @@

import copy
import warnings
from collections import defaultdict
import struct
import inspect

import numpy as np

from . import __version__

_DEFAULT_TAGS = {
    'non_deterministic': False,
    'requires_positive_data': False,
    'X_types': ['2darray'],
    'poor_score': False,
    'no_validation': False,
    'multioutput': False,
    'allow_nan': False,
    'stateless': False,
    'multilabel': False,
    '_skip_test': False,
    'multioutput_only': False}

Review thread:

- binary only would be an important tag for external libraries (and came up in
  the context of the GP here).
- Make sure you're clear that it's binary targets, not features
- Binary only is also relevant for calibration methods.
def clone(estimator, safe=True):

@@ -61,7 +74,6 @@ def clone(estimator, safe=True):

    return new_object


###############################################################################
def _pprint(params, offset=0, printer=repr):
    """Pretty print the dictionary 'params'

@@ -112,7 +124,17 @@ def _pprint(params, offset=0, printer=repr):

    return lines
###############################################################################
def _update_if_consistent(dict1, dict2):
    common_keys = set(dict1.keys()).intersection(dict2.keys())
    for key in common_keys:
        if dict1[key] != dict2[key]:
            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
                key, dict1[key], dict2[key]
            ))
    dict1.update(dict2)
    return dict1

Review thread:

- Right, but then this would error if the
- yes, which I solved by having the
- (I think)
class BaseEstimator:
    """Base class for all estimators in scikit-learn

@@ -135,7 +157,7 @@ def _get_param_names(cls):

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]

@@ -255,8 +277,22 @@ def __setstate__(self, state):

        except AttributeError:
            self.__dict__.update(state)
    def _get_tags(self):
        collected_tags = {}
        for base_class in inspect.getmro(self.__class__):
            if (hasattr(base_class, '_more_tags')
                    and base_class != self.__class__):
                more_tags = base_class._more_tags(self)
                collected_tags = _update_if_consistent(collected_tags,
                                                       more_tags)
        if hasattr(self, '_more_tags'):
            more_tags = self._more_tags()
            collected_tags = _update_if_consistent(collected_tags, more_tags)
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags

Review thread:

- Don't we need to reverse this list to give precedence to tags set earlier in
  the MRO? The precedence should be tested either way. (I think the official
  idiom might be
- hm after thinking about this again, this looks like we're running into the
  same MRO issue that I was having earlier. I don't think @rth's solution
  actually works.
- Yes, you are right, it is method resolution order, e.g.::

      <class '__main__.LinearRegression'>
      <class '__main__.BaseEstimator'>
      <class '__main__.ClassifierMixin'>

  So first, if a tag is defined in the first estimator, we don't want to
  overwrite it, i.e.::

      for key, val in base_class._more_tags(self).items():
          if key not in tags:
              tags[key] = val

  (or something similar), instead of
  ``tags.update(base_class._more_tags(self))``. Then you are right that we
  want the tags from within the mixin to apply before the base estimators.
  Maybe we want to sort::

      def _mro_class_compare(args):
          """Adjust the ordering for some estimator classes,
          while preserving the MRO ordering for the rest"""
          position_init, cls = args
          offset = 0
          if cls.__name__ == 'BaseEstimator':
              # put the BaseEstimator last
              offset = 2000
          elif cls.__name__.startswith('Base'):
              # put any "Base.*" classes just before
              offset = 1000
          return position_init + offset

      # [...]
      for _, base_class in sorted(enumerate(inspect.getmro(type(self))),
                                  key=_mro_class_compare):
          # setting tags here

  it's a bit hackish, but might work. Here the output would be,
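The per-class collection walk can be seen with a toy hierarchy; this is a
simplified re-implementation (using a plain ``dict.update`` instead of the
consistency check), and ``ToyRegressor`` is invented for the example:

```python
import inspect

def get_tags(obj):
    # Walk the MRO and merge each base class's ``_more_tags`` contribution;
    # the instance's own ``_more_tags`` is applied last
    collected = {}
    for base_class in inspect.getmro(type(obj)):
        if hasattr(base_class, '_more_tags') and base_class is not type(obj):
            collected.update(base_class._more_tags(obj))
    if hasattr(obj, '_more_tags'):
        collected.update(obj._more_tags())
    return collected

class MultiOutputMixin:
    def _more_tags(self):
        return {'multioutput': True}

class ToyRegressor(MultiOutputMixin):
    def _more_tags(self):
        return {'allow_nan': True}

print(get_tags(ToyRegressor()))
# {'multioutput': True, 'allow_nan': True}
```

Both the mixin's tag and the concrete class's tag end up in the result, which
is the behavior the review thread above is debating the precedence rules for.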
###############################################################################
class ClassifierMixin:
    """Mixin class for all classifiers in scikit-learn."""
    _estimator_type = "classifier"

@@ -289,7 +325,6 @@ def score(self, X, y, sample_weight=None):

        return accuracy_score(y, self.predict(X), sample_weight=sample_weight)


###############################################################################
class RegressorMixin:
    """Mixin class for all regression estimators in scikit-learn."""
    _estimator_type = "regressor"
@@ -330,7 +365,6 @@ def score(self, X, y, sample_weight=None):

        multioutput='variance_weighted')


###############################################################################
class ClusterMixin:
    """Mixin class for all cluster estimators in scikit-learn."""
    _estimator_type = "clusterer"

@@ -432,7 +466,6 @@ def get_submatrix(self, i, data):

        return data[row_ind[:, np.newaxis], col_ind]


###############################################################################
class TransformerMixin:
    """Mixin class for all transformers in scikit-learn."""
@@ -510,13 +543,27 @@ def fit_predict(self, X, y=None):

        return self.fit(X).predict(X)


###############################################################################
class MetaEstimatorMixin:
    _required_parameters = ["estimator"]
    """Mixin class for all meta estimators in scikit-learn."""
    # this is just a tag for the moment


###############################################################################
class MultiOutputMixin(object):
    """Mixin to mark estimators that support multioutput."""
    def _more_tags(self):
        return {'multioutput': True}

Review comment: maybe this should set 'multilabel' if
def _is_32bit():
    """Detect if process is 32bit Python."""
    return struct.calcsize('P') * 8 == 32


class _UnstableOn32BitMixin(object):
    """Mark estimators that are non-deterministic on 32bit."""
    def _more_tags(self):
        return {'non_deterministic': _is_32bit()}
def is_classifier(estimator):
    """Returns True if the given estimator is (probably) a classifier.