Closed
Describe the bug
While working on a fix for issue #9626, I noticed that the current implementation of DecisionTreeRegressor(criterion="absolute_error") sometimes does not find the optimal split when sample weights are given.
It seems to happen only with a small number of samples, and the chosen split is not far from the optimal one.
My PR for #9626 will fix this one too. I'm opening this issue only to document the current behavior.
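For context: with the absolute-error criterion, the best constant prediction for a leaf is a weighted median of the leaf's targets, because a weighted median minimizes the sum of weighted absolute deviations. A weighted median can always be chosen among the leaf's own target values, which is why `abs_error_of_a_leaf` in the reproduction can simply take the minimum over the leaf's `y` values. The sketch below computes a (lower) weighted median directly, assuming strictly positive weights; `weighted_median` and `weighted_abs_error` are hypothetical helpers, not scikit-learn API.

```python
import numpy as np


def weighted_median(y, w):
    """Lower weighted median: the smallest y value m such that the total
    weight of samples with y <= m is at least half of the total weight.
    Assumes all weights are strictly positive (hypothetical helper)."""
    order = np.argsort(y, kind="stable")
    y_sorted = np.asarray(y, dtype=float)[order]
    w_sorted = np.asarray(w, dtype=float)[order]
    cum_w = np.cumsum(w_sorted)
    # First index where the cumulative weight reaches half of the total
    idx = np.searchsorted(cum_w, 0.5 * cum_w[-1])
    return y_sorted[idx]


def weighted_abs_error(y, w, c):
    """Sum of w_i * |y_i - c| for a constant prediction c."""
    return float((np.abs(np.asarray(y) - c) * np.asarray(w)).sum())


# Right-hand leaf of the expected split in the reproduction (samples 3, 4, 5)
y_leaf = np.array([3, 1, 2])
w_leaf = np.array([2., 1., 2.])
m = weighted_median(y_leaf, w_leaf)
print(m, weighted_abs_error(y_leaf, w_leaf, m))  # prints 2.0 3.0
```

Sorting once and using a cumulative sum keeps this at O(n log n) per leaf, instead of the O(n^2) brute force over candidate values.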
Steps/Code to Reproduce
```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def abs_error_of_a_leaf(y, w):
    # The weighted absolute error of a leaf is minimized at a weighted median,
    # which can always be taken among the leaf's own target values.
    return min((np.abs(y - yi) * w).sum() for yi in y)


def abs_error_of_leaves(leaves, y, w):
    # Total weighted absolute error, summed over all leaves
    return sum(abs_error_of_a_leaf(y[leaves == i], w[leaves == i]) for i in np.unique(leaves))


X = np.array([[1], [2], [3], [4], [5]])
y = np.array([1, 1, 3, 1, 2])
w = np.array([3., 3., 2., 1., 2.])

reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error')
sk_leaves = reg.fit(X, y, sample_weight=w).apply(X)
print("leaves:", sk_leaves, "total abs error:", abs_error_of_leaves(sk_leaves, y, w))
# prints [1 1 1 1 2] and 4.0
# Looking at X, y and w, it is easy to doubt that this split is the best one
expected_leaves = np.array([1, 1, 2, 2, 2])
print("total abs error:", abs_error_of_leaves(expected_leaves, y, w))
# prints 3.0 => indeed, the split returned by sklearn is not the best
```
Expected Results
The tree should choose the split that minimizes the weighted absolute error (AE).
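Written out, the quantity to minimize for a candidate split of a node's samples into a left set L and a right set R is given below (the notation L, R, c_L, c_R is introduced here for illustration, with samples numbered 1 to 5 in the order of X). Plugging in the data from the reproduction gives the two totals printed by the script:

```latex
\mathrm{AE}(L, R) = \min_{c_L} \sum_{i \in L} w_i \lvert y_i - c_L \rvert
                  + \min_{c_R} \sum_{i \in R} w_i \lvert y_i - c_R \rvert

% Split returned by scikit-learn: L = \{1, 2, 3, 4\}, R = \{5\}, c_L = 1, c_R = 2
\mathrm{AE} = (3 \cdot 0 + 3 \cdot 0 + 2 \cdot 2 + 1 \cdot 0) + 2 \cdot 0 = 4

% Expected split: L = \{1, 2\}, R = \{3, 4, 5\}, c_L = 1, c_R = 2
\mathrm{AE} = (3 \cdot 0 + 3 \cdot 0) + (2 \cdot 1 + 1 \cdot 1 + 2 \cdot 0) = 3
```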
Actual Results
Prints:
leaves: [1 1 1 1 2] total abs error: 4.0
total abs error: 3.0
This shows that the chosen split is not optimal.
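To double-check that 3.0 is the best total achievable by a depth-1 tree on this data, a brute-force search can evaluate every threshold between consecutive distinct feature values. This is a minimal sketch for a single feature; `best_depth1_split` and `leaf_abs_error` are hypothetical helpers (not scikit-learn API), and placing the threshold at the midpoint between neighbouring values is a convention chosen here.

```python
import numpy as np


def leaf_abs_error(y, w):
    # Optimal constant for the weighted absolute error is one of the leaf's y values
    return min(float((np.abs(y - c) * w).sum()) for c in y)


def best_depth1_split(x, y, w):
    """Exhaustively evaluate every threshold on a single feature and return
    (lowest total weighted absolute error, threshold)."""
    order = np.argsort(x, kind="stable")
    x, y, w = x[order], y[order], w[order]
    best = (np.inf, None)
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # equal feature values cannot be separated
        threshold = float((x[i - 1] + x[i]) / 2)
        err = leaf_abs_error(y[:i], w[:i]) + leaf_abs_error(y[i:], w[i:])
        if err < best[0]:
            best = (err, threshold)
    return best


x = np.array([1, 2, 3, 4, 5], dtype=float)
y = np.array([1, 1, 3, 1, 2], dtype=float)
w = np.array([3., 3., 2., 1., 2.])

err, threshold = best_depth1_split(x, y, w)
print(err, threshold)  # prints 3.0 2.5

# Split chosen by scikit-learn in the report (first four samples on the left)
print(leaf_abs_error(y[:4], w[:4]) + leaf_abs_error(y[4:], w[4:]))  # prints 4.0
```

The search is quadratic in the number of samples, which is fine for a five-sample reproduction but is only meant as a reference to compare against, not as an efficient split finder.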
Versions
System:
python: 3.12.11 (main, Aug 18 2025, 19:19:11) [Clang 20.1.4 ]
executable: /home/arthur/dev-perso/fast-mae-split/.venv/bin/python
machine: Linux-6.14.0-29-generic-x86_64-with-glibc2.39
Python dependencies:
sklearn: 1.7.1
pip: None
setuptools: 80.9.0
numpy: 2.2.6
scipy: 1.16.1
Cython: None
pandas: None
matplotlib: 3.10.6
joblib: 1.5.2
threadpoolctl: 3.6.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 16
prefix: libscipy_openblas
filepath: /home/arthur/dev-perso/fast-mae-split/.venv/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-56d6093b.so
version: 0.3.29
threading_layer: pthreads
architecture: Haswell
user_api: blas
internal_api: openblas
num_threads: 16
prefix: libscipy_openblas
filepath: /home/arthur/dev-perso/fast-mae-split/.venv/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
version: 0.3.28
threading_layer: pthreads
architecture: Haswell
user_api: openmp
internal_api: openmp
num_threads: 16
prefix: libgomp
filepath: /home/arthur/dev-perso/fast-mae-split/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None