Incorrect sample weight handling in KBinsDiscretizer #29906

Closed
@snath-xoc

Describe the bug

Sample weights are not subsampled together with X when subsample is specified in KBinsDiscretizer, so fit fails with a shape mismatch between the subsampled X and the full-length sample_weight.

Steps/Code to Reproduce

from sklearn.datasets import make_blobs
from sklearn.preprocessing import KBinsDiscretizer
import numpy as np

rng = np.random.RandomState(42)

# Five centres
centres = np.array([[0, 0], [0, 5], [3, 1], [2, 4], [8, 8]])
X, _ = make_blobs(
    n_samples=100,
    cluster_std=0.5,
    centers=centres,
    random_state=10,
)

# Randomly generate sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])

est = KBinsDiscretizer(
    n_bins=4, strategy='quantile', subsample=20, random_state=10
).fit(X, sample_weight=sample_weight)

Expected Results

No error is thrown

Actual Results

    253 if sample_weight is not None:
--> 254     sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
    256 bin_edges = np.zeros(n_features, dtype=object)
    257 for jj in range(n_features):

File ~/sklearn-dev/scikit-learn/sklearn/utils/validation.py:2133, in _check_sample_weight(sample_weight, X, dtype, copy, ensure_non_negative)
   2130         raise ValueError("Sample weights must be 1D array or scalar")
   2132     if sample_weight.shape != (n_samples,):
-> 2133         raise ValueError(
   2134             "sample_weight.shape == {}, expected {}!".format(
   2135                 sample_weight.shape, (n_samples,)
   2136             )
   2137         )
   2139 if ensure_non_negative:
   2140     check_non_negative(sample_weight, "`sample_weight`")

ValueError: sample_weight.shape == (100,), expected (20,)!
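
The shape mismatch above suggests that X is reduced to 20 rows while sample_weight keeps all 100 entries. As a rough sketch of one way to keep the two aligned, the same row indices can be applied to both arrays. The helper name `subsample_with_weights` is hypothetical and is not part of scikit-learn, and this is not necessarily how the library resolves the issue.

```python
import numpy as np


def subsample_with_weights(X, sample_weight, n_subsample, random_state=None):
    """Hypothetical helper: draw row indices once and index both arrays with them."""
    rng = np.random.RandomState(random_state)
    n_samples = X.shape[0]
    if n_subsample >= n_samples:
        # Nothing to subsample; return the inputs unchanged.
        return X, sample_weight
    # Sample row indices without replacement so each row appears at most once.
    idx = rng.choice(n_samples, size=n_subsample, replace=False)
    return X[idx], sample_weight[idx]


rng = np.random.RandomState(42)
X = rng.normal(size=(100, 2))  # stand-in data, not the blobs from the report
sample_weight = rng.randint(0, 10, size=X.shape[0])

X_sub, w_sub = subsample_with_weights(X, sample_weight, n_subsample=20, random_state=10)
print(X_sub.shape, w_sub.shape)
```

With this approach the weights passed on to any later validation step have the same length as the subsampled X, which is the condition the check in the traceback enforces.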

Versions

System:
    python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
   machine: macOS-14.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.dev0
          pip: 24.0
   setuptools: 70.1.1
        numpy: 2.0.0
        scipy: 1.14.0
       Cython: 3.0.10
       pandas: 2.2.2
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
...
    num_threads: 8
         prefix: libomp
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
        version: None
