-
-
Notifications
You must be signed in to change notification settings - Fork 26.6k
Closed
Description
When `myclassifier.decision_function` returns only zeros for a sample, the wrapping `OneVsRestClassifier` predicts the last class instead of the first one.
The cause is in `multiclass.py`, where a later class that the decision function scores as 0 overrides earlier classes with the same score; ties are therefore broken in favor of the last class, whereas `np.argmax` picks the first.
This violates the documented requirement (questioned in #13631) that the row-wise argmax of `decision_function` is the predicted class.
Steps/Code to Reproduce
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.estimator_checks import check_classifiers_train
from sklearn.utils.validation import check_X_y


class Dummy(BaseEstimator, ClassifierMixin):
    """Degenerate classifier that always predicts its first class.

    ``decision_function`` returns all zeros. When the estimator is wrapped in
    a one-vs-rest meta-estimator, every per-class score is therefore 0 and all
    classes tie. ``np.argmax`` resolves such a tie to the first class, which is
    exactly what ``predict`` returns here.
    """

    def fit(self, X, y):
        """Validate the training data and record the labels and feature count.

        Stores the sorted unique labels of ``y`` in ``classes_`` and the
        number of columns of ``X`` in ``n_features_``. Returns ``self``.
        """
        X, y = check_X_y(X, y, multi_output=False)
        self.classes_ = np.unique(y)
        self.n_features_ = X.shape[1]
        return self

    def _check_n_features(self, X):
        """Raise ValueError if ``X`` has a different column count than in fit."""
        n_features = X.shape[1]
        if self.n_features_ != n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features is %s and "
                             "input n_features is %s "
                             % (self.n_features_, n_features))

    def predict(self, X):
        """Return ``self.classes_[0]`` for every row of ``X``."""
        self._check_n_features(X)
        return np.full(len(X), self.classes_[0])

    def decision_function(self, X):
        """Return a 1-D array of zeros, one score per row of ``X``."""
        self._check_n_features(X)
        return np.zeros(len(X))
class POVR(OneVsRestClassifier):
def _more_tags(self):
return {'poor_score': True}>>> x = np.array([[0], [0], [0]])
>>> y = np.arange(3)
>>> d = Dummy().fit(x,y)
>>> d.predict(x)
array([0, 0, 0])
>>> d.decision_function(x)
array([0., 0., 0.])
>>> od = POVR(Dummy()).fit(x,y)
>>> od.classes_
array([0, 1, 2])
>>> od.predict(x)
array([2, 2, 2])
>>> od.decision_function(x)
array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
>>> np.argmax(od.decision_function(x), axis=1)
array([0, 0, 0]) # != od.predict(x)
>>> check_classifiers_train('t', POVR(Dummy()))
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-57-e45e548f419d> in <module>
----> 1 check_classifiers_train('t', POVR(Dummy()))
/usr/lib/python3.7/site-packages/sklearn/utils/testing.py in wrapper(*args, **kwargs)
353 with warnings.catch_warnings():
354 warnings.simplefilter("ignore", self.category)
--> 355 return fn(*args, **kwargs)
356
357 return wrapper
/usr/lib/python3.7/site-packages/sklearn/utils/estimator_checks.py in check_classifiers_train(name, classifier_orig, readonly_memmap)
1493 else:
1494 assert_equal(decision.shape, (n_samples, n_classes))
-> 1495 assert_array_equal(np.argmax(decision, axis=1), y_pred)
1496
1497 # raises error on malformed input for decision_function
/usr/lib/python3.7/site-packages/numpy/testing/_private/utils.py in assert_array_equal(x, y, err_msg, verbose)
902 __tracebackhide__ = True # Hide traceback for py.test
903 assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
--> 904 verbose=verbose, header='Arrays are not equal')
905
906
/usr/lib/python3.7/site-packages/numpy/testing/_private/utils.py in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
825 verbose=verbose, header=header,
826 names=('x', 'y'), precision=precision)
--> 827 raise AssertionError(msg)
828 except ValueError:
829 import traceback
AssertionError:
Arrays are not equal
Mismatch: 100%
Max absolute difference: 2
Max relative difference: 1.
x: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
y: array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,...Versions
/usr/lib/python3.7/site-packages/numpy/distutils/system_info.py:639: UserWarning:
Atlas (http://math-atlas.sourceforge.net/) libraries not found.
Directories to search for the libraries can be specified in the
numpy/distutils/site.cfg file (section [atlas]) or by setting
the ATLAS environment variable.
self.calc_info()
System:
python: 3.7.3 (default, Mar 26 2019, 21:43:19) [GCC 8.2.1 20181127]
executable: /usr/bin/python3
machine: Linux-5.1.11-arch1-1-ARCH-x86_64-with-arch
BLAS:
macros: NO_ATLAS_INFO=1, HAVE_CBLAS=None
lib_dirs: /usr/lib64
cblas_libs: cblas
Python deps:
pip: 19.0.3
setuptools: 41.0.1
sklearn: 0.21.2
numpy: 1.16.4
scipy: 1.3.0
Cython: 0.29.10
pandas: 0.24.2