Closed
Description
Describe the bug
The check check_classifiers_train fails for classifiers that need positive X, even when the tag requires_positive_X is set to True.
In the example below, I copied the template from skltemplate and added a check in fit to mimic the behaviour of classifiers that need positive X. The tag requires_positive_X is set to True in _more_tags().
Steps/Code to Reproduce
"""
This is a module to be used as a reference for building other modules
"""
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import euclidean_distances
from sklearn.utils.estimator_checks import parametrize_with_checks


class TemplateClassifier(ClassifierMixin, BaseEstimator):
    """ An example classifier which implements a 1-NN algorithm.

    For more information regarding how to build your own classifier, read more
    in the :ref:`User Guide <user_guide>`.

    Parameters
    ----------
    demo_param : str, default='demo'
        A parameter used for demonstration of how to pass and store parameters.

    Attributes
    ----------
    X_ : ndarray, shape (n_samples, n_features)
        The input passed during :meth:`fit`.
    y_ : ndarray, shape (n_samples,)
        The labels passed during :meth:`fit`.
    classes_ : ndarray, shape (n_classes,)
        The classes seen at :meth:`fit`.
    """

    def __init__(self, demo_param='demo'):
        self.demo_param = demo_param

    def fit(self, X, y):
        """A reference implementation of a fitting function for a classifier.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,)
            The target values. An array of int.

        Returns
        -------
        self : object
            Returns self.
        """
        # Check that X and y have correct shape
        X, y = check_X_y(X, y)
        # Check non-negative X for illustration
        if np.any(X < 0):
            raise ValueError("This classifier needs non-negative X")
        # Store the classes seen during fit
        self.classes_ = unique_labels(y)
        self.n_features_in_ = X.shape[1]
        self.X_ = X
        self.y_ = y
        # Return the classifier
        return self

    def predict(self, X):
        """ A reference implementation of a prediction for a classifier.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            The label for each sample is the label of the closest sample
            seen during fit.
        """
        # Check that fit has been called
        check_is_fitted(self, ['X_', 'y_'])
        # Input validation
        X = check_array(X)
        # Index of the nearest training sample for each input row
        closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
        return self.y_[closest]

    def _more_tags(self):
        return {
            'requires_y': True,
            'requires_positive_X': True,
        }


@parametrize_with_checks([TemplateClassifier()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
Expected Results
I would expect checks that pass negative values in X to be skipped, or to run on non-negative data, when the tag requires_positive_X is set to True.
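To make the expectation concrete, here is a minimal, dependency-free sketch of a tag-aware data generator. The helper name make_check_data and its parameters are hypothetical and are not part of scikit-learn; the sketch only illustrates the idea that a check could shift its test data to be non-negative when the estimator declares requires_positive_X.

```python
import random


def make_check_data(tags, n_samples=20, n_features=3, seed=0):
    """Hypothetical helper (not a scikit-learn API).

    Builds a small list-of-lists dataset and, if the tag dictionary
    declares requires_positive_X, shifts it so that no value is negative.
    """
    rng = random.Random(seed)
    # Standard-normal samples, so roughly half of the raw values are negative.
    X = [
        [rng.gauss(0.0, 1.0) for _ in range(n_features)]
        for _ in range(n_samples)
    ]
    if tags.get("requires_positive_X", False):
        # Subtract the global minimum so the smallest value becomes 0.
        lowest = min(min(row) for row in X)
        X = [[value - lowest for value in row] for row in X]
    return X


# Data for an estimator that declares the tag contains no negative values.
X_tagged = make_check_data({"requires_positive_X": True})
print(min(min(row) for row in X_tagged))
```

With data prepared this way, the ValueError raised in fit above would not be triggered by the check itself.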
Actual Results
=========================================================================== short test summary info ===========================================================================
FAILED bornclassifier.py::test_sklearn_compatible_estimator[TemplateClassifier()-check_classifiers_train] - ValueError: This classifier needs non-negative X
FAILED bornclassifier.py::test_sklearn_compatible_estimator[TemplateClassifier()-check_classifiers_train(readonly_memmap=True)] - ValueError: This classifier needs non-nega...
FAILED bornclassifier.py::test_sklearn_compatible_estimator[TemplateClassifier()-check_classifiers_train(readonly_memmap=True,X_dtype=float32)] - ValueError: This classifie...
================================================================== 3 failed, 40 passed, 4 warnings in 0.99s ===================================================================
Versions
System:
    python: 3.9.10 (main, Mar 16 2022, 16:37:35) [Clang 13.0.0 (clang-1300.0.27.3)]
    machine: macOS-12.2.1-arm64-i386-64bit

Python dependencies:
    sklearn: 1.1.2
    pip: 22.0.3
    setuptools: 60.6.0
    numpy: 1.22.3
    scipy: 1.8.0
    Cython: None
    pandas: 1.4.1
    matplotlib: 3.5.1
    joblib: 1.1.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    version: 0.3.18
    threading_layer: pthreads
    architecture: armv8
    num_threads: 8

    user_api: openmp
    internal_api: openmp
    prefix: libomp
    version: None
    num_threads: 8

    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    version: 0.3.17
    threading_layer: pthreads
    architecture: armv8
    num_threads: 8