From b700faacb4a430a664c2a9e45f95605fad5ff34e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 6 Jan 2022 13:58:29 +0100 Subject: [PATCH 01/10] MAINT Introduce FastEuclideanPairwiseArgKmin This reverts the main changes made by 09a9527 to make the initialization in __cinit__ instead of in __init__ because it's easier this way. If there's a way to maintain the initialization in __cinit__, let's do it. --- .../metrics/_pairwise_distances_reduction.pyx | 257 ++++++++++++++++-- .../test_pairwise_distances_reduction.py | 22 ++ sklearn/utils/__init__.py | 35 ++- sklearn/utils/_testing.py | 11 +- 4 files changed, 305 insertions(+), 20 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 830df08e1a952..45c8aebf031c6 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -16,7 +16,6 @@ cimport numpy as np import numpy as np import warnings -import scipy.sparse from .. import get_config from libc.stdlib cimport free, malloc @@ -24,16 +23,26 @@ from libc.float cimport DBL_MAX from cython cimport final from cython.parallel cimport parallel, prange -from ._dist_metrics cimport DatasetsPair +from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair +from ..utils._cython_blas cimport ( + BLAS_Order, + BLAS_Trans, + ColMajor, + NoTrans, + RowMajor, + Trans, + _dot, + _gemm, +) from ..utils._heap cimport simultaneous_sort, heap_push from ..utils._openmp_helpers cimport _openmp_thread_num -from ..utils._typedefs cimport ITYPE_t, DTYPE_t +from ..utils._typedefs cimport ITYPE_t, DTYPE_t, DITYPE_t from numbers import Integral from typing import List from scipy.sparse import issparse from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING -from ..utils import check_scalar +from ..utils import check_scalar, _in_unstable_openblas_configuration from ..utils.fixes import threadpool_limits from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils._typedefs import ITYPE, DTYPE @@ -41,6 +50,30 @@ from ..utils._typedefs import ITYPE, DTYPE np.import_array() +cpdef DTYPE_t[::1] _sqeuclidean_row_norms( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t idx = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] row_norms_squared = np.empty(n, dtype=DTYPE) + + for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): + row_norms_squared[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return row_norms_squared + + cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. @@ -629,15 +662,16 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return pda._finalize_results(return_distance) - def __cinit__( + def __init__( self, DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, - *args, - **kwargs, - ): + ITYPE_t k=1, + ): + self.k = check_scalar(k, "k", Integral, min_val=1) + # Allocating pointers to datastructures but not the datastructures themselves. # There are as many pointers as effective threads. # @@ -654,16 +688,6 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): sizeof(ITYPE_t *) * self.chunks_n_threads ) - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - n_threads=None, - strategy=None, - ITYPE_t k=1, - ): - self.k = check_scalar(k, "k", Integral, min_val=1) - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -837,3 +861,200 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances) return np.asarray(self.argkmin_indices) + + +cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): + """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance. + + The full pairwise squared distances matrix is computed as follows: + + ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² + + The middle term gets computed efficiently bellow using BLAS Level 3 GEMM. + + Notes + ----- + This implementation has a superior arithmetic intensity and hence + better running time when the alternative is IO bound, but it can suffer + from numerical instability caused by catastrophic cancellation potentially + introduced by the subtraction in the arithmetic expression above. + + PairwiseDistancesArgKmin with EuclideanDistance must be used when higher + numerical precision is needed. + """ + + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + # Buffers for GEMM + DTYPE_t ** dist_middle_terms_chunks + bint use_squared_distances + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and + not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + ITYPE_t k, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + metric_kwargs=None, + ): + if metric_kwargs is not None and len(metric_kwargs) > 0: + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't" + f"usable for this case ({self.__class__.__name__}) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + k=k, + chunk_size=chunk_size, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + self.X, self.Y = datasets_pair.X, datasets_pair.Y + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared", None) + else: + self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms(self.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + + def __dealloc__(self): + if self.dist_middle_terms_chunks is not NULL: + free(self.dist_middle_terms_chunks) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesArgKmin.compute_exact_distances(self) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num) + + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num) + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_init(self) + + for thread_num in range(self.chunks_n_threads): + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_finalize(self) + + for thread_num in range(self.chunks_n_threads): + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + DTYPE_t * A = & X_c[0, 0] + ITYPE_t lda = X_c.shape[1] + DTYPE_t * B = & Y_c[0, 0] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + DTYPE_t * C = dist_middle_terms + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc) + + # Pushing the distance and their associated indices on heaps + # which keep tracks of the argkmin. + for i in range(X_c.shape[0]): + for j in range(Y_c.shape[0]): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * Y_c.shape[0] + j] + + self.Y_norm_squared[j + Y_start] + ), + j + Y_start, + ) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index a4d51e4662740..d293e7bf6027e 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -7,6 +7,7 @@ from sklearn.metrics._pairwise_distances_reduction import ( PairwiseDistancesReduction, PairwiseDistancesArgKmin, + _sqeuclidean_row_norms, ) from sklearn.utils.fixes import sp_version, parse_version @@ -355,3 +356,24 @@ def test_pairwise_distances_argkmin( ASSERT_RESULT[PairwiseDistancesArgKmin]( argkmin_distances, argkmin_distances_ref, argkmin_indices, argkmin_indices_ref ) + + +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("n_features", [5, 10, 100]) +@pytest.mark.parametrize("num_threads", [1, 2, 8]) +def test_sqeuclidean_row_norms( + seed, + n_samples, + n_features, + num_threads, + dtype=np.float64, +): + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + + sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 + sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) + + assert_allclose(sq_row_norm_reference, sq_row_norm) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 3d8a1ca87d210..4b2261ad7c2f4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -26,7 +26,7 @@ from . import _joblib from ..exceptions import DataConversionWarning from .deprecation import deprecated -from .fixes import np_version, parse_version +from .fixes import np_version, parse_version, threadpool_info from ._estimator_html_repr import estimator_html_repr from .validation import ( as_float_array, @@ -81,6 +81,39 @@ _IS_32BIT = 8 * struct.calcsize("P") == 32 +def _in_unstable_openblas_configuration(): + """Return True if in an unstable configuration for OpenBLAS""" + + # Import libraries which might load OpenBLAS. + import numpy # noqa + import scipy # noqa + + modules_info = threadpool_info() + + open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info) + if not open_blas_used: + return False + + # OpenBLAS 0.3.16 fixed unstability for arm64, see: + # https://github.com/xianyi/OpenBLAS/blob/1b6db3dbba672b4f8af935bd43a1ff6cff4d20b7/Changelog.txt#L56-L58 # noqa + openblas_arm64_stable_version = parse_version("0.3.16") + for info in modules_info: + if info["internal_api"] != "openblas": + continue + openblas_version = info.get("version") + openblas_architecture = info.get("architecture") + if openblas_version is None or openblas_architecture is None: + # Cannot be sure that OpenBLAS is good enough. Assume unstable: + return True + if ( + openblas_architecture == "neoversen1" + and parse_version(openblas_version) < openblas_arm64_stable_version + ): + # See discussions in https://github.com/numpy/numpy/issues/19411 + return True + return False + + class Bunch(dict): """Container object exposing keys as attributes. diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 1724063be2f43..6f58ce3f3b7b4 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -48,7 +48,12 @@ import joblib import sklearn -from sklearn.utils import IS_PYPY, _IS_32BIT, deprecated +from sklearn.utils import ( + IS_PYPY, + _IS_32BIT, + deprecated, + _in_unstable_openblas_configuration, +) from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import ( check_array, @@ -448,6 +453,10 @@ def set_random_state(estimator, random_state=0): os.environ.get("TRAVIS") == "true", reason="skip on travis" ) fails_if_pypy = pytest.mark.xfail(IS_PYPY, reason="not compatible with PyPy") + fails_if_unstable_openblas = pytest.mark.xfail( + _in_unstable_openblas_configuration(), + reason="OpenBLAS is unstable for this configuration", + ) skip_if_no_parallel = pytest.mark.skipif( not joblib.parallel.mp, reason="joblib is in serial mode" ) From c7be245986a6b4a7a0c5a3a0ec80d46dae9b1792 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 6 Jan 2022 14:09:39 +0100 Subject: [PATCH 02/10] fixup! MAINT Introduce FastEuclideanPairwiseArgKmin --- sklearn/metrics/_pairwise_distances_reduction.pyx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 45c8aebf031c6..f04f986ce1be5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -62,16 +62,18 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( # Casting for X to remove the const qualifier is needed because APIs # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' # const qualifier. + # See: https://github.com/scipy/scipy/issues/1426 DTYPE_t * X_ptr = &X[0, 0] ITYPE_t idx = 0 ITYPE_t n = X.shape[0] ITYPE_t d = X.shape[1] - DTYPE_t[::1] row_norms_squared = np.empty(n, dtype=DTYPE) + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): - row_norms_squared[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return squared_row_norms - return row_norms_squared cdef class PairwiseDistancesReduction: From 0ff95cfa6f1f65be82fdf36b5917046dfe45bf6f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 6 Jan 2022 14:28:37 +0100 Subject: [PATCH 03/10] Fix imports and typos --- sklearn/metrics/_pairwise_distances_reduction.pyx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index f04f986ce1be5..9e828bd8368be 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -36,7 +36,7 @@ from ..utils._cython_blas cimport ( ) from ..utils._heap cimport simultaneous_sort, heap_push from ..utils._openmp_helpers cimport _openmp_thread_num -from ..utils._typedefs cimport ITYPE_t, DTYPE_t, DITYPE_t +from ..utils._typedefs cimport ITYPE_t, DTYPE_t from numbers import Integral from typing import List @@ -62,7 +62,7 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( # Casting for X to remove the const qualifier is needed because APIs # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' # const qualifier. - # See: https://github.com/scipy/scipy/issues/1426 + # See: https://github.com/scipy/scipy/issues/14262 DTYPE_t * X_ptr = &X[0, 0] ITYPE_t idx = 0 ITYPE_t n = X.shape[0] @@ -80,7 +80,7 @@ cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. Subclasses of this class compute pairwise distances between a set of - row vectors of X and another set of row vectors pf Y and apply a reduction on top. + row vectors of X and another set of row vectors of Y and apply a reduction on top. The reduction takes a matrix of pairwise distances between rows of X and Y as input and outputs an aggregate data-structure for each row of X. The aggregate values are typically smaller than the number of rows in Y, @@ -1030,6 +1030,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): DTYPE_t alpha = - 2. # Casting for A and B to remove the const is needed because APIs exposed via # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 DTYPE_t * A = & X_c[0, 0] ITYPE_t lda = X_c.shape[1] DTYPE_t * B = & Y_c[0, 0] From 2f612784e55a84bc864df940139e791dea22a7f4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 7 Jan 2022 18:09:11 +0100 Subject: [PATCH 04/10] TST Adapt test_pairwise_distances_argkmin Co-authored-by: Olivier Grisel --- .../metrics/tests/test_pairwise_distances_reduction.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index d293e7bf6027e..e975aad55bf9c 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -10,6 +10,7 @@ _sqeuclidean_row_norms, ) +from sklearn.metrics import euclidean_distances from sklearn.utils.fixes import sp_version, parse_version # Common supported metric between scipy.spatial.distance.cdist @@ -306,7 +307,7 @@ def test_strategies_consistency( @pytest.mark.parametrize("n_features", [50, 500]) -@pytest.mark.parametrize("translation", [0, 1e8]) +@pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) def test_pairwise_distances_argkmin( @@ -331,7 +332,11 @@ def test_pairwise_distances_argkmin( metric_kwargs = _get_dummy_metric_params_list(metric, n_features)[0] # Reference for argkmin results - dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) + if metric == "euclidean": + # Compare to scikit-learn GEMM optimized implementation + dist_matrix = euclidean_distances(X, Y) + else: + dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) # Taking argkmin (indices of the k smallest values) argkmin_indices_ref = np.argsort(dist_matrix, axis=1)[:, :k] # Getting the associated distances From 03b6c79e04c8d0a23a23493a6426907dd982a9ea Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 7 Jan 2022 18:14:12 +0100 Subject: [PATCH 05/10] MAINT Reorder initilizations to move allocations in __cinit__ Co-authored-by: Thomas J. Fan Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 129 ++++++++++++++---- 1 file changed, 102 insertions(+), 27 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 9e828bd8368be..39670619f2808 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -215,9 +215,20 @@ cdef class PairwiseDistancesReduction: not issparse(Y) and Y.dtype == np.float64 and metric in cls.valid_metrics()) + # About __cinit__ and __init__ signatures: + # + # - __cinit__ is responsible for C-level allocations and initializations + # - __init__ is responsible for PyObject initialization + # - for a given class, __cinit__ and __init__ must have a matching signatures + # (up to *args and **kwargs) + # - for a given class hiearchy __cinit__'s must have a matching signatures + # (up to *args and **kwargs) + # + # See: https://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#initialisation-methods-cinit-and-init #noqa def __cinit__( self, - DatasetsPair datasets_pair, + n_samples_X, + n_samples_Y, chunk_size=None, n_threads=None, strategy=None, @@ -234,9 +245,7 @@ cdef class PairwiseDistancesReduction: self.effective_n_threads = _openmp_effective_n_threads(n_threads) - self.datasets_pair = datasets_pair - - self.n_samples_X = datasets_pair.n_samples_X() + self.n_samples_X = n_samples_X self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk @@ -247,7 +256,7 @@ cdef class PairwiseDistancesReduction: else: self.X_n_samples_last_chunk = self.X_n_samples_chunk - self.n_samples_Y = datasets_pair.n_samples_Y() + self.n_samples_Y = n_samples_Y self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk @@ -281,6 +290,17 @@ cdef class PairwiseDistancesReduction: self.effective_n_threads, ) + def __init__( + self, + n_samples_X, + n_samples_Y, + chunk_size=None, + n_threads=None, + strategy=None, + DatasetsPair datasets_pair=None, + ): + self.datasets_pair = datasets_pair + @final cdef void _parallel_on_X(self) nogil: """Compute the pairwise distances of each row vector of X on Y @@ -647,12 +667,30 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # for various back-end and/or hardware and/or datatypes, and/or fused # {sparse, dense}-datasetspair etc. - pda = PairwiseDistancesArgKmin( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) + if metric in ("euclidean", "sqeuclidean") and not issparse(X) and not issparse(Y): + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesArgKmin( + n_samples_X=X.shape[0], + n_samples_Y=Y.shape[0], + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + X=X, + Y=Y, + k=k, + use_squared_distances=use_squared_distances, + metric_kwargs=metric_kwargs, + ) + else: + pda = PairwiseDistancesArgKmin( + n_samples_X=X.shape[0], + n_samples_Y=Y.shape[0], + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + k=k, + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -664,16 +702,16 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return pda._finalize_results(return_distance) - def __init__( + def __cinit__( self, - DatasetsPair datasets_pair, + n_samples_X, + n_samples_Y, chunk_size=None, n_threads=None, strategy=None, - ITYPE_t k=1, - ): - self.k = check_scalar(k, "k", Integral, min_val=1) - + *args, + **kwargs, + ): # Allocating pointers to datastructures but not the datastructures themselves. # There are as many pointers as effective threads. # @@ -690,6 +728,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): sizeof(ITYPE_t *) * self.chunks_n_threads ) + def __init__( + self, + n_samples_X, + n_samples_Y, + chunk_size=None, + n_threads=None, + strategy=None, + DatasetsPair datasets_pair=None, + ITYPE_t k=1, + ): + super().__init__( + n_samples_X=n_samples_X, + n_samples_Y=n_samples_Y, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + datasets_pair=datasets_pair, + ) + self.k = check_scalar(k, "k", Integral, min_val=1) + # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -900,14 +958,32 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) + def __cinit__( + self, + n_samples_X, + n_samples_Y, + chunk_size=None, + n_threads=None, + strategy=None, + *args, + **kwargs, + ): + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + def __init__( self, - X, - Y, - ITYPE_t k, - bint use_squared_distances=False, + n_samples_X, + n_samples_Y, chunk_size=None, + n_threads=None, strategy=None, + X=None, + Y=None, + ITYPE_t k=1, + bint use_squared_distances=False, metric_kwargs=None, ): if metric_kwargs is not None and len(metric_kwargs) > 0: @@ -919,10 +995,14 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) super().__init__( + n_samples_X=n_samples_X, + n_samples_Y=n_samples_Y, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, # The datasets pair here is used for exact distances computations datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), k=k, - chunk_size=chunk_size, ) # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: @@ -941,11 +1021,6 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) self.use_squared_distances = use_squared_distances - # Temporary datastructures used in threads - self.dist_middle_terms_chunks = malloc( - sizeof(DTYPE_t *) * self.chunks_n_threads - ) - def __dealloc__(self): if self.dist_middle_terms_chunks is not NULL: free(self.dist_middle_terms_chunks) From 9b88e35d397df0008b6782c14bdaaca9bd010057 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 10 Jan 2022 12:55:45 +0100 Subject: [PATCH 06/10] Revert "MAINT Reorder initilizations to move allocations in __cinit__" This reverts commit 03b6c79e04c8d0a23a23493a6426907dd982a9ea. --- .../metrics/_pairwise_distances_reduction.pyx | 129 ++++-------------- 1 file changed, 27 insertions(+), 102 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 39670619f2808..9e828bd8368be 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -215,20 +215,9 @@ cdef class PairwiseDistancesReduction: not issparse(Y) and Y.dtype == np.float64 and metric in cls.valid_metrics()) - # About __cinit__ and __init__ signatures: - # - # - __cinit__ is responsible for C-level allocations and initializations - # - __init__ is responsible for PyObject initialization - # - for a given class, __cinit__ and __init__ must have a matching signatures - # (up to *args and **kwargs) - # - for a given class hiearchy __cinit__'s must have a matching signatures - # (up to *args and **kwargs) - # - # See: https://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#initialisation-methods-cinit-and-init #noqa def __cinit__( self, - n_samples_X, - n_samples_Y, + DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, @@ -245,7 +234,9 @@ cdef class PairwiseDistancesReduction: self.effective_n_threads = _openmp_effective_n_threads(n_threads) - self.n_samples_X = n_samples_X + self.datasets_pair = datasets_pair + + self.n_samples_X = datasets_pair.n_samples_X() self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk @@ -256,7 +247,7 @@ cdef class PairwiseDistancesReduction: else: self.X_n_samples_last_chunk = self.X_n_samples_chunk - self.n_samples_Y = n_samples_Y + self.n_samples_Y = datasets_pair.n_samples_Y() self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk @@ -290,17 +281,6 @@ cdef class PairwiseDistancesReduction: self.effective_n_threads, ) - def __init__( - self, - n_samples_X, - n_samples_Y, - chunk_size=None, - n_threads=None, - strategy=None, - DatasetsPair datasets_pair=None, - ): - self.datasets_pair = datasets_pair - @final cdef void _parallel_on_X(self) nogil: """Compute the pairwise distances of each row vector of X on Y @@ -667,30 +647,12 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # for various back-end and/or hardware and/or datatypes, and/or fused # {sparse, dense}-datasetspair etc. - if metric in ("euclidean", "sqeuclidean") and not issparse(X) and not issparse(Y): - use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesArgKmin( - n_samples_X=X.shape[0], - n_samples_Y=Y.shape[0], - chunk_size=chunk_size, - n_threads=n_threads, - strategy=strategy, - X=X, - Y=Y, - k=k, - use_squared_distances=use_squared_distances, - metric_kwargs=metric_kwargs, - ) - else: - pda = PairwiseDistancesArgKmin( - n_samples_X=X.shape[0], - n_samples_Y=Y.shape[0], - chunk_size=chunk_size, - n_threads=n_threads, - strategy=strategy, - k=k, - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - ) + pda = PairwiseDistancesArgKmin( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -702,16 +664,16 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return pda._finalize_results(return_distance) - def __cinit__( + def __init__( self, - n_samples_X, - n_samples_Y, + DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, - *args, - **kwargs, - ): + ITYPE_t k=1, + ): + self.k = check_scalar(k, "k", Integral, min_val=1) + # Allocating pointers to datastructures but not the datastructures themselves. # There are as many pointers as effective threads. # @@ -728,26 +690,6 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): sizeof(ITYPE_t *) * self.chunks_n_threads ) - def __init__( - self, - n_samples_X, - n_samples_Y, - chunk_size=None, - n_threads=None, - strategy=None, - DatasetsPair datasets_pair=None, - ITYPE_t k=1, - ): - super().__init__( - n_samples_X=n_samples_X, - n_samples_Y=n_samples_Y, - chunk_size=chunk_size, - n_threads=n_threads, - strategy=strategy, - datasets_pair=datasets_pair, - ) - self.k = check_scalar(k, "k", Integral, min_val=1) - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -958,32 +900,14 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) - def __cinit__( - self, - n_samples_X, - n_samples_Y, - chunk_size=None, - n_threads=None, - strategy=None, - *args, - **kwargs, - ): - # Temporary datastructures used in threads - self.dist_middle_terms_chunks = malloc( - sizeof(DTYPE_t *) * self.chunks_n_threads - ) - def __init__( self, - n_samples_X, - n_samples_Y, + X, + Y, + ITYPE_t k, + bint use_squared_distances=False, chunk_size=None, - n_threads=None, strategy=None, - X=None, - Y=None, - ITYPE_t k=1, - bint use_squared_distances=False, metric_kwargs=None, ): if metric_kwargs is not None and len(metric_kwargs) > 0: @@ -995,14 +919,10 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) super().__init__( - n_samples_X=n_samples_X, - n_samples_Y=n_samples_Y, - chunk_size=chunk_size, - n_threads=n_threads, - strategy=strategy, # The datasets pair here is used for exact distances computations datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), k=k, + chunk_size=chunk_size, ) # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: @@ -1021,6 +941,11 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) self.use_squared_distances = use_squared_distances + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + def __dealloc__(self): if self.dist_middle_terms_chunks is not NULL: free(self.dist_middle_terms_chunks) From e27bb755f1b5d70a1067c7a7b30e0bbcd1d365c4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 10 Jan 2022 15:06:38 +0100 Subject: [PATCH 07/10] Simply plug FastEuclideanPairwiseDistancesArgKmin in --- .../metrics/_pairwise_distances_reduction.pyx | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 9e828bd8368be..891dd16d799c9 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -215,14 +215,12 @@ cdef class PairwiseDistancesReduction: not issparse(Y) and Y.dtype == np.float64 and metric in cls.valid_metrics()) - def __cinit__( + def __init__( self, DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, - *args, - **kwargs, ): cdef: ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks @@ -646,13 +644,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # For future work, this might can be an entrypoint to specialise operations # for various back-end and/or hardware and/or datatypes, and/or fused # {sparse, dense}-datasetspair etc. - - pda = PairwiseDistancesArgKmin( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesArgKmin( + X=X, Y=Y, k=k, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + else: # Fall back on the default + pda = PairwiseDistancesArgKmin( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -672,6 +683,12 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): strategy=None, ITYPE_t k=1, ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + ) self.k = check_scalar(k, "k", Integral, min_val=1) # Allocating pointers to datastructures but not the datastructures themselves. @@ -907,6 +924,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ITYPE_t k, bint use_squared_distances=False, chunk_size=None, + n_threads=None, strategy=None, metric_kwargs=None, ): @@ -921,8 +939,10 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): super().__init__( # The datasets pair here is used for exact distances computations datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), - k=k, chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + k=k, ) # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: From a0b46eba1cf5a931e0caf87bf4ea3cfbbac16f63 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 10 Jan 2022 15:39:04 +0100 Subject: [PATCH 08/10] Remove old remark --- sklearn/metrics/_pairwise_distances_reduction.pyx | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 891dd16d799c9..9fa2b32a2e80b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -897,9 +897,6 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): better running time when the alternative is IO bound, but it can suffer from numerical instability caused by catastrophic cancellation potentially introduced by the subtraction in the arithmetic expression above. - - PairwiseDistancesArgKmin with EuclideanDistance must be used when higher - numerical precision is needed. """ cdef: From 7805e8137972ddc32a0519487c4425ff1d2b6795 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 12 Jan 2022 10:51:00 +0100 Subject: [PATCH 09/10] Add indications for the dispatch of the implementations Co-authored-by: Olivier Grisel --- sklearn/metrics/_pairwise_distances_reduction.pyx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 9fa2b32a2e80b..72b78ce9c42ee 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -649,6 +649,10 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): and not issparse(X) and not issparse(Y) ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. use_squared_distances = metric == "sqeuclidean" pda = FastEuclideanPairwiseDistancesArgKmin( X=X, Y=Y, k=k, @@ -657,7 +661,9 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): strategy=strategy, metric_kwargs=metric_kwargs, ) - else: # Fall back on the default + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. pda = PairwiseDistancesArgKmin( datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), k=k, From a4f39cad7ebef4a17da02eadb29c033c796d939d Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 13 Jan 2022 16:08:20 +0100 Subject: [PATCH 10/10] Simplify and do not use indirection Co-authored-by: Thomas J. Fan --- sklearn/metrics/_pairwise_distances_reduction.pyx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 72b78ce9c42ee..76420b50a1b5e 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -953,7 +953,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): self.X, self.Y = datasets_pair.X, datasets_pair.Y if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared", None) + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) @@ -1059,11 +1059,10 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): DTYPE_t * B = & Y_c[0, 0] ITYPE_t ldb = X_c.shape[1] DTYPE_t beta = 0. - DTYPE_t * C = dist_middle_terms ITYPE_t ldc = Y_c.shape[0] # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc) + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) # Pushing the distance and their associated indices on heaps # which keep tracks of the argkmin.