Description
Describe the bug
utils.multiclass.type_of_target returns 'multiclass' when y is the target of a regression problem that only takes integer values, even after y is cast to "float".
This appeared when I was working on the Ames housing dataset, where prices are integers but the task is regression. After splitting the dataset into train and test sets, the number of unique values (the 'classes') in a split can exceed half the number of samples in that split, so type_of_target raises the warning:
"The number of unique classes is greater than 50% of the number of samples. `y` could represent a regression problem, not a classification problem."
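The condition in the warning can be restated as a simple ratio test. The sketch below uses a hypothetical helper, unique_ratio_exceeds, that only restates the message; scikit-learn's actual check may differ in details such as rounding. The prices are made-up illustrative values, not taken from the Ames data. It also hints at why splitting matters: a smaller sample tends to contain fewer repeated prices, so the share of unique values grows.

```python
import numpy as np


def unique_ratio_exceeds(y, threshold=0.5):
    """Hypothetical helper: return True when the number of unique values
    in y is greater than threshold * n_samples, i.e. the condition stated
    in the warning message (not scikit-learn's exact implementation)."""
    y = np.asarray(y)
    n_samples = y.shape[0]
    n_unique = np.unique(y).shape[0]
    return bool(n_unique > threshold * n_samples)


# Illustrative integer-valued prices (made up, not taken from the Ames data):
# 5 distinct values in 5 samples is far above 50%.
prices = np.array([120000.0, 135500.0, 98000.0, 210000.0, 175000.0])
print(unique_ratio_exceeds(prices))  # True

# A typical binary label vector: 2 distinct values in 6 samples.
labels = np.array([0, 1, 0, 1, 1, 0])
print(unique_ratio_exceeds(labels))  # False
```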
The warning appeared while using RandomForestRegressor with oob_score=True, which calls type_of_target(y). It can be worked around on the user side by adding a small non-integer float to every element of the target, but I think this should be fixed on the library side. Note that casting the target to "float" does not fix the issue.
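Why casting to "float" does not help can be sketched in NumPy. The hypothetical helper has_non_integer_values below mirrors, as far as I understand it, the idea behind the float branch of type_of_target: a float target is only treated as continuous if some value differs from its integer cast. Casting changes the dtype but not the values, so integer-valued floats still fail that check. The same sketch shows the user-side workaround and one caveat: the offset must not be so small that it is lost to floating-point rounding at the magnitude of the targets. The prices are made up.

```python
import numpy as np


def has_non_integer_values(y):
    """Hypothetical helper: True if y is a float array and at least one
    value differs from its integer cast (i.e. has a fractional part)."""
    y = np.asarray(y)
    return bool(y.dtype.kind == "f" and np.any(y != y.astype(int)))


# Illustrative integer-valued prices (made up), first as integers, then cast.
y_int = np.array([120000, 135500, 98000, 210000])
y_float = y_int.astype("float")

print(y_float.dtype)                     # float64
print(has_non_integer_values(y_float))   # False: same values, new dtype

# Workaround from the report: shift every value by a non-integer amount.
print(has_non_integer_values(y_float + 0.5))    # True

# Caveat: at these magnitudes, 1e-12 is much smaller than the float64
# spacing, so each sum rounds back to the original integer value.
print(has_non_integer_values(y_float + 1e-12))  # False
```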
Steps/Code to Reproduce
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import type_of_target
X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
y = y.astype("float")
X = X.drop(columns="Id")
X = X.select_dtypes("number")
imputer = SimpleImputer(strategy="mean")
X = imputer.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
print(type_of_target(y_test))
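fetch_openml needs network access. As a hedged alternative, the sketch below uses a made-up integer-valued float target instead of the Ames data; the array sizes, value range and random_state are arbitrary assumptions. Whether the UserWarning is emitted depends on the installed scikit-learn version, so the snippet records and prints any warnings rather than assuming one.

```python
import warnings

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import type_of_target

# Synthetic stand-in for SalePrice: integer-valued "prices" stored as floats.
rng = np.random.default_rng(0)
y = rng.integers(50_000, 500_000, size=1000).astype("float")
X = rng.normal(size=(1000, 3))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Record any warnings instead of letting them print to stderr.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    target_type = type_of_target(y_test)

print(target_type)  # multiclass
for w in caught:
    print(w.category.__name__, w.message)
```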
Expected Results
Outputs 'continuous' with no warning.
Actual Results
typeoftarget_bug_reproducer.py:17: UserWarning: The number of unique classes is greater than 50% of the number of samples. `y` could represent a regression problem, not a classification problem.
print(type_of_target(y_test))
multiclass
Versions
System:
python: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0]
executable: /home/gaeta/.sklearn-env/bin/python3
machine: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Python dependencies:
sklearn: 1.8.dev0
pip: 25.1.1
setuptools: 80.8.0
numpy: 2.2.6
scipy: 1.15.3
Cython: 3.1.3
pandas: 2.2.3
matplotlib: 3.6.0
joblib: 1.5.2
threadpoolctl: 3.6.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 14
prefix: libscipy_openblas
filepath: /home/gaeta/.sklearn-env/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-56d6093b.so
version: 0.3.29
threading_layer: pthreads
architecture: Haswell
user_api: blas
internal_api: openblas
num_threads: 14
prefix: libscipy_openblas
filepath: /home/gaeta/.sklearn-env/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
version: 0.3.28
threading_layer: pthreads
architecture: Haswell
user_api: openmp
internal_api: openmp
num_threads: 14
prefix: libgomp
filepath: /usr/lib/x86_64-linux-gnu/libgomp.so.1.0.0
version: None