diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 28c1cc40542e2..17dedcf5b326a 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -138,6 +138,16 @@ Support for Python 3.4 and below has been officially dropped. :pr:`12344` by :user:`Adrin Jalali `. +:mod:`sklearn.metrics` +...................... + +- |Feature| Added ``algorithm`` parameter to :func:`metrics.euclidean_distances` + to compute Euclidean distances without the quadratic expansion formula, + which is slower but more precise numerically, particularly in 32 bit. + Also added a global ``euclidean_distances_algorithm`` config parameter + with the same effect. :issue:`12136` by `Roman Yurchak`_. + + Multiple modules ................ diff --git a/sklearn/_config.py b/sklearn/_config.py index bcd206ca9a688..ccb6ece892c5a 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -5,7 +5,8 @@ _global_config = { 'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)), - 'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)) + 'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)), + 'euclidean_distance_algorithm': 'quadratic-expansion' } @@ -20,7 +21,8 @@ def get_config(): return _global_config.copy() -def set_config(assume_finite=None, working_memory=None): +def set_config(assume_finite=None, working_memory=None, + euclidean_distance_algorithm=None): """Set global scikit-learn configuration .. versionadded:: 0.19 @@ -43,11 +45,22 @@ def set_config(assume_finite=None, working_memory=None): .. versionadded:: 0.20 + euclidean_distance_algorithm : {str, None} + Method of computing the euclidean distances: "exact" uses + ``scipy.spatial.distance.cdist`` while "quadratic-expansion" uses + a faster but less precise quadratic expansion. For sparse data, only + "quadratic-expansion" is supported. + Global default: "quadratic-expansion" + + .. versionadded:: 0.21 """ if assume_finite is not None: _global_config['assume_finite'] = assume_finite if working_memory is not None: _global_config['working_memory'] = working_memory + if euclidean_distance_algorithm is not None: + _global_config['euclidean_distance_algorithm'] = ( + euclidean_distance_algorithm) @contextmanager @@ -68,6 +81,15 @@ def config_context(**new_config): computation time and memory on expensive operations that can be performed in chunks. Global default: 1024. + euclidean_distance_algorithm : {str, None} + Method of computing the euclidean distances: "exact" uses + ``scipy.spatial.distance.cdist`` while "quadratic-expansion" uses + a faster but less precise quadratic expansion. For sparse data, only + "quadratic-expansion" is supported. + Global default: "quadratic-expansion" + + .. versionadded:: 0.21 + Notes ----- All settings, not just those presently modified, will be returned to diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index f6d28252551cb..d6cfdc06fbac7 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -25,10 +25,12 @@ from ..utils import gen_batches, get_chunk_n_rows from ..utils.extmath import row_norms, safe_sparse_dot from ..preprocessing import normalize +from ..utils import get_config from ..utils._joblib import Parallel from ..utils._joblib import delayed from ..utils._joblib import effective_n_jobs + from .pairwise_fast import _chi2_kernel_fast, _sparse_manhattan @@ -163,13 +165,13 @@ def check_paired_arrays(X, Y): # Pairwise distances def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, - X_norm_squared=None): + X_norm_squared=None, algorithm=None): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. - For efficiency reasons, the euclidean distance between a pair of row - vector x and y is computed as:: + For efficiency reasons, by default, the euclidean distance between a + pair of row vector x and y is computed as:: dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) @@ -181,6 +183,12 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, However, this is not the most precise way of doing this computation, and the distance matrix returned by this function may not be exactly symmetric as required by, e.g., ``scipy.spatial.distance`` functions. + To use a slower but exact approach for dense data, either provide + `algorithm="exact"` or set the global ``euclidean_distance_algorithm`` + parameter:: + + with sklearn.config_context(euclidean_distance_algorithm='exact'): + knn = KNeighboursClassifier(algorithm='brute', metric='euclidean') Read more in the :ref:`User Guide `. @@ -201,6 +209,18 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, Pre-computed dot-products of vectors in X (e.g., ``(X**2).sum(axis=1)``) + algorithm : {str, None}, default: None + Method of computing the euclidean distances: "exact" uses + ``scipy.spatial.distance.cdist`` while "quadratic-expansion" uses + a faster but less precise quadratic expansion. For sparse data, only + "quadratic-expansion" is supported. + + When None (default), the value of + ``sklearn.get_config()['euclidean_distance_algorithm']`` is used ( + default: "quadratic-expansion") + + .. versionadded:: 0.21 + Returns ------- distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2) @@ -224,6 +244,20 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, """ X, Y = check_pairwise_arrays(X, Y) + if algorithm is None: + algorithm = get_config()['euclidean_distance_algorithm'] + + if algorithm not in ['exact', 'quadratic-expansion']: + raise ValueError('algorithm=%s invalid, must be one of ' + '"exact", "quadratic-expansion"' % algorithm) + + if algorithm == 'exact': + if issparse(X) or issparse(Y): + raise ValueError("algorithm='exact' does not support sparse data") + + metric = 'sqeuclidean' if squared else 'euclidean' + return distance.cdist(X, Y, metric) + if X_norm_squared is not None: XX = check_array(X_norm_squared) if XX.shape == (1, X.shape[0]): diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 05c5f48e45340..54a174dc6cff8 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -5,9 +5,11 @@ from scipy.sparse import dok_matrix, csr_matrix, issparse from scipy.spatial.distance import cosine, cityblock, minkowski, wminkowski +from scipy.spatial.distance import cdist import pytest +import sklearn from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_allclose @@ -881,3 +883,28 @@ def test_check_preserve_type(): XB.astype(np.float)) assert_equal(XA_checked.dtype, np.float) assert_equal(XB_checked.dtype, np.float) + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +def test_euclidean_distance_algorithm(dtype): + XA = np.random.RandomState(42).rand(100, 10).astype(dtype) + XB = np.random.RandomState(41).rand(200, 10).astype(dtype) + + dist_exact = euclidean_distances(XA, XB, algorithm='exact') + assert_allclose(dist_exact, cdist(XA, XB, 'euclidean')) + + dist_exact_squared = euclidean_distances(XA, XB, algorithm='exact', + squared=True) + + assert_allclose(dist_exact_squared, dist_exact**2) + + dist_approx = euclidean_distances(XA, XB, algorithm='quadratic-expansion') + assert_allclose(dist_exact, dist_approx, rtol=1e-5) + + with sklearn.config_context(euclidean_distance_algorithm='exact'): + assert_allclose(dist_exact, + euclidean_distances(XA, XB)) + + with pytest.raises(ValueError, + match="algorithm='exact' does not support sparse data"): + euclidean_distances(csr_matrix(XA), csr_matrix(XB), algorithm='exact') diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index efaa57f850367..930834a9b4145 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -3,14 +3,18 @@ def test_config_context(): - assert get_config() == {'assume_finite': False, 'working_memory': 1024} + assert get_config() == { + 'assume_finite': False, 'working_memory': 1024, + 'euclidean_distance_algorithm': 'quadratic-expansion'} # Not using as a context manager affects nothing config_context(assume_finite=True) assert get_config()['assume_finite'] is False with config_context(assume_finite=True): - assert get_config() == {'assume_finite': True, 'working_memory': 1024} + assert get_config() == { + 'assume_finite': True, 'working_memory': 1024, + 'euclidean_distance_algorithm': 'quadratic-expansion'} assert get_config()['assume_finite'] is False with config_context(assume_finite=True): @@ -34,7 +38,9 @@ def test_config_context(): assert get_config()['assume_finite'] is True - assert get_config() == {'assume_finite': False, 'working_memory': 1024} + assert get_config() == { + 'assume_finite': False, 'working_memory': 1024, + 'euclidean_distance_algorithm': 'quadratic-expansion'} # No positional arguments assert_raises(TypeError, config_context, True)