Describe the bug
When running euclidean_distances, I think it is possible to reach the line distances += XX in _euclidean_distances (line 380 of sklearn/metrics/pairwise.py in the traceback below) with XX being None. This happens when the inputs X and Y are float64 but X_norm_squared and Y_norm_squared are float32.
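The branching I suspect can be modelled in a few lines of NumPy. This is a simplified, hypothetical reconstruction: the helper select_xx is made up for illustration, and the real code in sklearn/metrics/pairwise.py may differ in detail. The idea is that supplied norms are only kept if they are not float32, and X's own norms are only computed when no norms are supplied. So a float64 X paired with float32 X_norm_squared falls through to None.

```python
import numpy as np


def select_xx(X, X_norm_squared=None):
    """Hypothetical model of how XX is chosen; returns a column vector or None."""
    if X_norm_squared is not None:
        if X_norm_squared.dtype != np.float32:
            # Usable non-float32 norms: reshape to a column for broadcasting.
            return X_norm_squared.reshape(-1, 1)
        # float32 norms are discarded, on the assumption that a float32
        # (upcasting) code path will recompute them later.
        return None
    if X.dtype != np.float32:
        # No norms supplied: compute squared row norms from X directly.
        return np.einsum("ij,ij->i", X, X)[:, np.newaxis]
    return None


x = np.arange(15, dtype=np.float64).reshape(5, 3)
xx32 = np.einsum("ij,ij->i", x, x).astype(np.float32)

print(select_xx(x) is None)        # False: norms computed from float64 X
print(select_xx(x, xx32) is None)  # True: float64 X with float32 norms
```

If the real code behaves like this model, nothing downstream restores XX when X itself is float64, so the float64 branch ends up adding None to the distance matrix.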
Steps/Code to Reproduce
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
x = np.random.randint(0, 10, size=(100, 5)).astype(np.float64)
y = np.random.randint(0, 10, size=(100, 5)).astype(np.float64)
xx = np.einsum("ij,ij->i", x, x)
yy = np.einsum("ij,ij->i", y, y)
euclidean_distances(x, y, Y_norm_squared=yy.astype(np.float32), X_norm_squared=xx.astype(np.float32))
Expected Results
No error is thrown
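One way the expected behaviour could be reached, sketched here as a hypothetical fix rather than what scikit-learn actually does: when X is not float32, fall back to recomputing float64 norms instead of returning None. The helper select_xx_fixed is made up for illustration. The check below confirms that the expansion ||x||^2 - 2 x.y + ||y||^2 built this way matches a direct pairwise computation. As a workaround on the user side, passing the norms as float64 (or not passing them at all) presumably avoids the failing branch.

```python
import numpy as np


def select_xx_fixed(X, X_norm_squared=None):
    """Hypothetical fix: never return None when X itself is not float32."""
    if X_norm_squared is not None and X_norm_squared.dtype != np.float32:
        return X_norm_squared.reshape(-1, 1)
    if X.dtype != np.float32:
        # Recompute float64 norms rather than relying on float32 ones.
        return np.einsum("ij,ij->i", X, X)[:, np.newaxis]
    return None  # float32 X: left to a separate upcasting path (not modelled)


rng = np.random.default_rng(0)
x = rng.integers(0, 10, size=(100, 5)).astype(np.float64)
y = rng.integers(0, 10, size=(100, 5)).astype(np.float64)
xx32 = np.einsum("ij,ij->i", x, x).astype(np.float32)
yy32 = np.einsum("ij,ij->i", y, y).astype(np.float32)

XX = select_xx_fixed(x, xx32)     # shape (100, 1)
YY = select_xx_fixed(y, yy32).T   # shape (1, 100), a row for broadcasting

# ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2
sq = -2 * x @ y.T
sq += XX
sq += YY
np.maximum(sq, 0, out=sq)  # clip tiny negatives caused by rounding
dist = np.sqrt(sq)

# Reference: direct pairwise Euclidean distances via broadcasting.
direct = np.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1))
print(np.allclose(dist, direct))
```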
Actual Results
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ageorgiou/projects/test/venv/lib/python3.10/site-packages/sklearn/metrics/pairwise.py", line 338, in euclidean_distances
    return _euclidean_distances(X, Y, X_norm_squared, Y_norm_squared, squared)
  File "/home/ageorgiou/projects/test/venv/lib/python3.10/site-packages/sklearn/metrics/pairwise.py", line 380, in _euclidean_distances
    distances += XX
numpy.core._exceptions._UFuncOutputCastingError: Cannot cast ufunc 'add' output from dtype('O') to dtype('float64') with casting rule 'same_kind'
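The error message itself is what NumPy raises whenever None is added in place to a float64 array, independent of scikit-learn. A minimal sketch (the 3x3 zeros array simply stands in for the distance matrix):

```python
import numpy as np

distances = np.zeros((3, 3), dtype=np.float64)
XX = None  # stands in for the value the float64 branch receives in the bug

caught = None
try:
    distances += XX
except TypeError as exc:
    # NumPy treats None as an object scalar, so 'add' produces dtype('O'),
    # which cannot be cast back into the float64 array in place.
    caught = exc
    print(type(exc).__name__, exc)
```

This is why the traceback mentions dtype('O') even though every array involved is numeric.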
Versions
System:
python: 3.10.6 (main, Oct 30 2022, 13:35:37) [GCC 12.2.0]
executable: /home/ageorgiou/projects/test/venv/bin/python
machine: Linux-6.2.12-1-MANJARO-x86_64-with-glibc2.38
Python dependencies:
sklearn: 1.3.1
pip: 22.2.1
setuptools: 63.2.0
numpy: 1.26.1
scipy: 1.11.3
Cython: None
pandas: None
matplotlib: None
joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /home/ageorgiou/projects/test/venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
version: 0.3.23.dev
threading_layer: pthreads
architecture: Haswell

user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libgomp
filepath: /home/ageorgiou/projects/test/venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None

user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /home/ageorgiou/projects/test/venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
version: 0.3.21.dev
threading_layer: pthreads
architecture: Haswell