Closed
Description
Describe the bug
Since sklearn 1.4, passing astropy.table.Table objects to splitting utilities such as train_test_split or KFold raises a ValueError.
Steps/Code to Reproduce
from astropy.table import Table
from sklearn.model_selection import KFold
from astropy import __version__ as astropy_version
from sklearn import __version__ as sklearn_version
print(f"sklearn: {sklearn_version}, astropy: {astropy_version}")
t = Table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]})
fold = KFold(2)
print(next(fold.split(t)))
Expected Results
❯ python sklearn_astropy.py
sklearn: 1.3.2, astropy: 6.0.0
(array([2, 3]), array([0, 1]))
Actual Results
sklearn: 1.4.0, astropy: 6.0.0
Traceback (most recent call last):
  File "/home/maxnoe/test/astropy_skllearn_1.4/sklearn_astropy.py", line 12, in <module>
    print(next(fold.split(t)))
          ^^^^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/model_selection/_split.py", line 367, in split
    X, y, groups = indexable(X, y, groups)
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 476, in indexable
    check_consistent_length(*result)
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 427, in check_consistent_length
    lengths = [_num_samples(X) for X in arrays if X is not None]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 427, in <listcomp>
    lengths = [_num_samples(X) for X in arrays if X is not None]
               ^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 351, in _num_samples
    if _use_interchange_protocol(x):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 288, in _use_interchange_protocol
    return not _is_pandas_df(X) and hasattr(X, "__dataframe__")
               ^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 2073, in _is_pandas_df
    if hasattr(X, "columns") and hasattr(X, "iloc"):
                                 ^^^^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/astropy/table/table.py", line 1046, in iloc
    return TableILoc(self)
           ^^^^^^^^^^^^^^^
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/astropy/table/index.py", line 949, in __init__
    super().__init__(table)
  File "/home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/astropy/table/index.py", line 832, in __init__
    raise ValueError("Cannot create TableLoc object with no indices")
ValueError: Cannot create TableLoc object with no indices
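The traceback points at the hasattr(X, "iloc") check: hasattr only swallows AttributeError, so a property that raises any other exception propagates to the caller. The following is a minimal sketch of that mechanism that needs neither sklearn nor astropy. FakeTable, looks_like_pandas_df and looks_like_pandas_df_safe are hypothetical names made up for illustration; the "safe" variant is just one possible defensive pattern, not necessarily what either library does or should do.

```python
class FakeTable:
    """Hypothetical stand-in for a table whose `iloc` property can fail."""

    columns = ["a", "b"]

    @property
    def iloc(self):
        # Mirrors the behaviour seen in the traceback: the property raises
        # ValueError (not AttributeError) when no index is defined.
        raise ValueError("Cannot create TableLoc object with no indices")


def looks_like_pandas_df(obj):
    """Same duck-typing expression as the line shown in the traceback."""
    return hasattr(obj, "columns") and hasattr(obj, "iloc")


def looks_like_pandas_df_safe(obj):
    """Hypothetical defensive variant: any failure means 'not a DataFrame'."""
    try:
        return hasattr(obj, "columns") and hasattr(obj, "iloc")
    except Exception:
        return False


t = FakeTable()

try:
    looks_like_pandas_df(t)
    outcome = "no error"
except ValueError as exc:
    # hasattr() does not catch ValueError, so it ends up here.
    outcome = f"ValueError: {exc}"

print(outcome)
print(looks_like_pandas_df_safe(t))
```

In this sketch the plain check fails with the same ValueError message as above, while the defensive variant simply reports that the object is not a pandas DataFrame.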
Versions
System:
python: 3.11.6 (main, Nov 14 2023, 09:36:21) [GCC 13.2.1 20230801]
executable: /home/maxnoe/test/astropy_skllearn_1.4/venv/bin/python
machine: Linux-6.1.70-1-MANJARO-x86_64-with-glibc2.38
Python dependencies:
sklearn: 1.4.0
pip: 23.2.1
setuptools: 65.5.0
numpy: 1.26.3
scipy: 1.11.4
Cython: None
pandas: None
matplotlib: None
joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libgomp
filepath: /home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
version: 0.3.23.dev
threading_layer: pthreads
architecture: Haswell
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /home/maxnoe/test/astropy_skllearn_1.4/venv/lib/python3.11/site-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
version: 0.3.21.dev
threading_layer: pthreads
architecture: Haswell