OneVsRestClassifier violates predict(X)==argmax(decision_function) #14124

@azrdev

Description

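When the wrapped classifier's decision_function returns 0 for every class of some sample, the wrapping OneVsRestClassifier predicts the last class instead of the first one.

The cause is in multiclass.py. OneVsRestClassifier.predict tracks a running maximum over the per-class scores, and a later class whose score ties that maximum (here, 0) overwrites the index recorded for an earlier class.
This violates the documented requirement (questioned in #13631) that the predicted class is the row-wise argmax of decision_function, because np.argmax breaks ties in favor of the first class.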
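
The tie-breaking difference can be reproduced with NumPy alone. The sketch below is a simplified re-implementation, assuming the multiclass branch of OneVsRestClassifier.predict keeps a running maximum and reassigns the winning index wherever the current class's score equals it. The function name running_max_predict is hypothetical and not part of scikit-learn.

```python
import numpy as np


def running_max_predict(scores):
    """Hypothetical re-implementation of a running-max argmax loop.

    scores has shape (n_samples, n_classes); column i holds the score
    produced by the i-th one-vs-rest estimator.
    """
    n_samples, n_classes = scores.shape
    maxima = np.full(n_samples, -np.inf)
    argmaxima = np.zeros(n_samples, dtype=int)
    for i in range(n_classes):
        pred = scores[:, i]
        np.maximum(maxima, pred, out=maxima)
        # On a tie, maxima == pred is True again, so a later class
        # overwrites the index recorded for an earlier one.
        argmaxima[maxima == pred] = i
    return argmaxima


scores = np.zeros((3, 3))  # every class scored 0, as with the Dummy estimator
print(running_max_predict(scores))  # [2 2 2]: last tied class wins
print(np.argmax(scores, axis=1))    # [0 0 0]: first tied class wins
```

Without ties both approaches pick the same column; they only disagree when several classes share the maximal score.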

Steps/Code to Reproduce

import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.estimator_checks import check_classifiers_train
from sklearn.utils.validation import check_X_y


class Dummy(BaseEstimator, ClassifierMixin):
    """Always predicts the first class and scores every sample 0."""

    def fit(self, X, y):
        X, y = check_X_y(X, y, multi_output=False)
        self.classes_ = np.unique(y)
        self.n_features_ = X.shape[1]
        return self

    def predict(self, X):
        # check_classifiers_train requires an error on a feature-count mismatch
        n_features = X.shape[1]
        if self.n_features_ != n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features is %s and "
                             "input n_features is %s "
                             % (self.n_features_, n_features))
        return np.full(len(X), self.classes_[0])

    def decision_function(self, X):
        n_features = X.shape[1]
        if self.n_features_ != n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features is %s and "
                             "input n_features is %s "
                             % (self.n_features_, n_features))
        # Binary score of 0 for every sample: a tie at the decision threshold
        return np.zeros(len(X))


class POVR(OneVsRestClassifier):
    # poor_score: skip the accuracy assertion in check_classifiers_train,
    # so only the predict/argmax consistency check can fail
    def _more_tags(self):
        return {'poor_score': True}

>>> x = np.array([[0], [0], [0]])
>>> y = np.arange(3)
>>> d = Dummy().fit(x,y)
>>> d.predict(x)
    array([0, 0, 0])
>>> d.decision_function(x)
    array([0., 0., 0.])
>>> od = POVR(Dummy()).fit(x,y)
>>> od.classes_
    array([0, 1, 2])
>>> od.predict(x)
    array([2, 2, 2])
>>> od.decision_function(x)
    array([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]])
>>> np.argmax(od.decision_function(x), axis=1)
    array([0, 0, 0])  # != od.predict(x)

>>> check_classifiers_train('t', POVR(Dummy()))
    ---------------------------------------------------------------------------

    AssertionError                            Traceback (most recent call last)

    <ipython-input-57-e45e548f419d> in <module>
    ----> 1 check_classifiers_train('t', POVR(Dummy()))
    

    /usr/lib/python3.7/site-packages/sklearn/utils/testing.py in wrapper(*args, **kwargs)
        353             with warnings.catch_warnings():
        354                 warnings.simplefilter("ignore", self.category)
    --> 355                 return fn(*args, **kwargs)
        356 
        357         return wrapper


    /usr/lib/python3.7/site-packages/sklearn/utils/estimator_checks.py in check_classifiers_train(name, classifier_orig, readonly_memmap)
       1493                 else:
       1494                     assert_equal(decision.shape, (n_samples, n_classes))
    -> 1495                     assert_array_equal(np.argmax(decision, axis=1), y_pred)
       1496 
       1497                 # raises error on malformed input for decision_function


    /usr/lib/python3.7/site-packages/numpy/testing/_private/utils.py in assert_array_equal(x, y, err_msg, verbose)
        902     __tracebackhide__ = True  # Hide traceback for py.test
        903     assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
    --> 904                          verbose=verbose, header='Arrays are not equal')
        905 
        906 


    /usr/lib/python3.7/site-packages/numpy/testing/_private/utils.py in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
        825                                 verbose=verbose, header=header,
        826                                 names=('x', 'y'), precision=precision)
    --> 827             raise AssertionError(msg)
        828     except ValueError:
        829         import traceback


    AssertionError: 
    Arrays are not equal
    
    Mismatch: 100%
    Max absolute difference: 2
    Max relative difference: 1.
     x: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
     y: array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,...
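
The failing assertion compares np.argmax(decision, axis=1) with y_pred. The same consistency check can be run directly on a fitted classifier whose decision_function has shape (n_samples, n_classes). The helper argmax_consistent below is hypothetical (not a scikit-learn function); it maps the argmax through classes_ so it also works for non-integer labels. TiedScores is an assumed stand-in that mimics the output of POVR(Dummy()) above: all-zero scores, last class predicted.

```python
import numpy as np


def argmax_consistent(clf, X):
    """Return True if predict(X) equals the row-wise argmax of decision_function(X).

    Hypothetical helper mirroring the multiclass assertion in
    check_classifiers_train.
    """
    decision = np.asarray(clf.decision_function(X))
    y_pred = np.asarray(clf.predict(X))
    expected = clf.classes_[np.argmax(decision, axis=1)]
    return bool(np.array_equal(expected, y_pred))


class TiedScores:
    """Stand-in mimicking POVR(Dummy()): all scores 0, last class predicted."""
    classes_ = np.array([0, 1, 2])

    def decision_function(self, X):
        return np.zeros((len(X), len(self.classes_)))

    def predict(self, X):
        return np.full(len(X), self.classes_[-1])


X = np.zeros((3, 1))
print(argmax_consistent(TiedScores(), X))  # False
```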

Versions

/usr/lib/python3.7/site-packages/numpy/distutils/system_info.py:639: UserWarning: 
    Atlas (http://math-atlas.sourceforge.net/) libraries not found.
    Directories to search for the libraries can be specified in the
    numpy/distutils/site.cfg file (section [atlas]) or by setting
    the ATLAS environment variable.
  self.calc_info()

System:
    python: 3.7.3 (default, Mar 26 2019, 21:43:19)  [GCC 8.2.1 20181127]
executable: /usr/bin/python3
   machine: Linux-5.1.11-arch1-1-ARCH-x86_64-with-arch

BLAS:
    macros: NO_ATLAS_INFO=1, HAVE_CBLAS=None
  lib_dirs: /usr/lib64
cblas_libs: cblas

Python deps:
       pip: 19.0.3
setuptools: 41.0.1
   sklearn: 0.21.2
     numpy: 1.16.4
     scipy: 1.3.0
    Cython: 0.29.10
    pandas: 0.24.2
