From 23c175b248f5fc90f314f1d078cee629fe449144 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 1 Jun 2022 17:50:05 +0200 Subject: [PATCH 01/19] MAINT Introduce interfaces for PairwiseDistancesReductions Those interfaces are meant to be used in the Python code, decoupling the actual implementation from the Python code. This allows changing all the private implementation while maintaining a contract for the Python callers. Each interface extending the base `PairwiseDistancesReduction` interface must implement the :meth:`compute` classmethod. Under the hood, such a function must only define the logic to dispatch at runtime to the correct dtype-specialized `PairwiseDistancesReduction` implementation based on the dtype of X and of Y. This refactoring will ease other dtype support such as float32 support. --- .../metrics/_pairwise_distances_reduction.pyx | 971 +++++++++++++----- 1 file changed, 715 insertions(+), 256 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 9191efae2a8da..abbcaaf8b5678 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -25,7 +25,6 @@ from libcpp.vector cimport vector from cython cimport final from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange -from cpython.ref cimport Py_INCREF from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair from ..utils._cython_blas cimport ( @@ -53,7 +52,6 @@ from ..utils.fixes import threadpool_limits from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils._typedefs import ITYPE, DTYPE - cnp.import_array() # TODO: change for `libcpp.algorithm.move` once Cython 3 is used @@ -82,8 +80,7 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( """Coerce a std::vector of std::vector to a ndarray of ndarray.""" cdef: ITYPE_t n = deref(vecs).size() - cnp.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, - dtype=np.ndarray) + cnp.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, dtype=np.ndarray) for i in range(n): nd_arrays_of_nd_arrays[i] = vector_to_nd_array(&(deref(vecs)[i])) @@ -117,7 +114,19 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( return squared_row_norms ##################### +# Interfaces: +# Those interfaces are meant to be used in the Python code, decoupling the +# actual implementation from the Python code. This allows changing all the +# private implementation while maintaining a contract for the Python callers. +# +# Each interface extending the base `PairwiseDistancesReduction` interface must +# implement the :meth:`compute` classmethod. +# +# Under the hood, such a function must only define the logic to dispatch +# at runtime to the correct dtype-specialized `PairwiseDistancesReduction` +# implementation based on the dtype of X and of Y. +# Base interface cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. @@ -183,32 +192,6 @@ cdef class PairwiseDistancesReduction: `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. """ - cdef: - readonly DatasetsPair datasets_pair - - # The number of threads that can be used is stored in effective_n_threads. - # - # The number of threads to use in the parallelisation strategy - # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: - # for small datasets, less threads might be needed to loop over pair of chunks. - # - # Hence the number of threads that _will_ be used for looping over chunks - # is stored in chunks_n_threads, allowing solely using what we need. - # - # Thus, an invariant is: - # - # chunks_n_threads <= effective_n_threads - # - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - - ITYPE_t n_samples_chunk, chunk_size - - ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk - ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk - - bint execute_in_parallel_on_Y - @classmethod def valid_metrics(cls) -> List[str]: excluded = { @@ -244,12 +227,479 @@ cdef class PairwiseDistancesReduction: ------- True if the PairwiseDistancesReduction can be used, else False. """ - # TODO: support sparse arrays and 32 bits + dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 return (get_config().get("enable_cython_pairwise_dist", True) and - not issparse(X) and X.dtype == np.float64 and - not issparse(Y) and Y.dtype == np.float64 and + not issparse(X) and not issparse(Y) and dtypes_validity and metric in cls.valid_metrics()) + +cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of row vectors of X on the ones of Y. + + For each row vector of X, computes the indices of k first the rows + vectors of Y with the smallest distances. + + PairwiseDistancesArgKmin is typically used to perform + bruteforce k-nearest neighbors queries. + + Parameters + ---------- + datasets_pair: DatasetsPair + The dataset pairs (X, Y) for the reduction. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + k: int, default=1 + The k for the argkmin reduction. + """ + + @classmethod + def compute( + cls, + X, + Y, + ITYPE_t k, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + If return_distance=True: + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistancesArgKmin`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows decoupling the interface entirely from the + implementation details whilst maintaining RAII. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesArgKmin64.compute( + X=X, + Y=Y, + k=k, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + +cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): + """Compute radius-based neighbors for two sets of vectors. + + For each row-vector X[i] of the queries X, find all the indices j of + row-vectors in Y such that: + + dist(X[i], Y[j]) <= radius + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + Parameters + ---------- + datasets_pair: DatasetsPair + The dataset pair (X, Y) for the reduction. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + radius: float + The radius defining the neighborhood. + """ + + @classmethod + def compute( + cls, + X, + Y, + DTYPE_t radius, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + bint sort_results=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + radius : float + The radius defining the neighborhood. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its neighbors if set to True. + + sort_results : boolean, default=False + Sort results with respect to distances between each X vector and its + neighbors if set to True. + + Returns + ------- + If return_distance=False: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + + If return_distance=True: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + - neighbors_distances : ndarray of n_samples_X ndarray + Distances to the neighbors for each vector in X. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistancesRadiusNeighborhood`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesRadiusNeighborhood64.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + +##################### +# dtype-specific implementations: +# For each dtype, an implementation of `PairwiseDistancesReductions` are generated by Tempita. +# Computations are dispatched to them at runtime via the interfaces defined above. +# +# Also, other helper are dtype-specialised. + +cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t i = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) + + for i in prange(n, schedule='static', nogil=True, num_threads=num_threads): + squared_row_norms[i] = _dot(d, X_ptr + i * d, 1, X_ptr + i * d, 1) + + return squared_row_norms + +cdef class GEMMTermComputer64: + """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. + + `FastEuclidean*` classes internally compute the squared Euclidean distances between + chunks of vectors X_c and Y_c using the following decomposition: + + + ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + + + This helper class is in charge of wrapping the common logic to compute + the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high + arithmetic intensity. + """ + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + ITYPE_t dist_middle_terms_chunks_size + ITYPE_t n_features + ITYPE_t chunk_size + + # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM + vector[vector[DTYPE_t]] dist_middle_terms_chunks + + def __init__(self, + DTYPE_t[:, ::1] X, + DTYPE_t[:, ::1] Y, + ITYPE_t effective_n_threads, + ITYPE_t chunks_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ITYPE_t n_features, + ITYPE_t chunk_size, + ): + self.X = X + self.Y = Y + self.effective_n_threads = effective_n_threads + self.chunks_n_threads = chunks_n_threads + self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size + self.n_features = n_features + self.chunk_size = chunk_size + + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) + + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + return + + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: + self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + return + + cdef void _parallel_on_Y_init(self) nogil: + for thread_num in range(self.chunks_n_threads): + self.dist_middle_terms_chunks[thread_num].resize( + self.dist_middle_terms_chunks_size + ) + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + return + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num + ) nogil: + return + + cdef DTYPE_t * _compute_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * A = &X_c[0, 0] + DTYPE_t * B = &Y_c[0, 0] + ITYPE_t lda = X_c.shape[1] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + + return dist_middle_terms + +cdef class PairwiseDistancesReduction64(PairwiseDistancesReduction): + """64bit implementation of PairwiseDistancesReduction.""" + + cdef: + readonly DatasetsPair datasets_pair + + # The number of threads that can be used is stored in effective_n_threads. + # + # The number of threads to use in the parallelisation strategy + # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: + # for small datasets, less threads might be needed to loop over pair of chunks. + # + # Hence the number of threads that _will_ be used for looping over chunks + # is stored in chunks_n_threads, allowing solely using what we need. + # + # Thus, an invariant is: + # + # chunks_n_threads <= effective_n_threads + # + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + + ITYPE_t n_samples_chunk, chunk_size + + ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk + ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + + bint execute_in_parallel_on_Y + def __init__( self, DatasetsPair datasets_pair, @@ -348,7 +798,8 @@ cdef class PairwiseDistancesReduction: X_end = X_start + self.X_n_samples_chunk # Reinitializing thread datastructures for the new X chunk - self._parallel_on_X_init_chunk(thread_num, X_start) + # Eventually upcast X[X_start:X_end] to 64bit + self._parallel_on_X_init_chunk(thread_num, X_start, X_end) for Y_chunk_idx in range(self.Y_n_chunks): Y_start = Y_chunk_idx * self.Y_n_samples_chunk @@ -357,6 +808,13 @@ cdef class PairwiseDistancesReduction: else: Y_end = Y_start + self.Y_n_samples_chunk + # Eventually upcast Y[Y_start:Y_end] to 64bit + self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self._compute_and_reduce_distances_on_chunks( X_start, X_end, Y_start, Y_end, @@ -409,7 +867,8 @@ cdef class PairwiseDistancesReduction: thread_num = _openmp_thread_num() # Initializing datastructures used in this thread - self._parallel_on_Y_parallel_init(thread_num) + # Eventually upcast X[X_start:X_end] to 64bit + self._parallel_on_Y_parallel_init(thread_num, X_start, X_end) for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): Y_start = Y_chunk_idx * self.Y_n_samples_chunk @@ -418,6 +877,13 @@ cdef class PairwiseDistancesReduction: else: Y_end = Y_start + self.Y_n_samples_chunk + # Eventually upcast Y[Y_start:Y_end] to 64bit + self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self._compute_and_reduce_distances_on_chunks( X_start, X_end, Y_start, Y_end, @@ -450,8 +916,9 @@ cdef class PairwiseDistancesReduction: ) nogil: """Compute the pairwise distances on two chunks of X and Y and reduce them. - This is THE core computational method of PairwiseDistanceReductions. - This must be implemented in subclasses. + This is THE core computational method of PairwiseDistanceReductions64. + This must be implemented in subclasses agnostically from the parallelisation + strategies. """ return @@ -479,10 +946,25 @@ cdef class PairwiseDistancesReduction: self, ITYPE_t thread_num, ITYPE_t X_start, + ITYPE_t X_end, ) nogil: """Initialise datastructures used in a thread given its number.""" return + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. + + This is eventually used to upcast X[X_start:X_end] to 64bit. + """ + return + cdef void _parallel_on_X_prange_iter_finalize( self, ITYPE_t thread_num, @@ -508,8 +990,24 @@ cdef class PairwiseDistancesReduction: cdef void _parallel_on_Y_parallel_init( self, ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, ) nogil: - """Initialise datastructures used in a thread given its number.""" + """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. + + This is eventually used to upcast Y[Y_start:Y_end] to 64bit. + """ return cdef void _parallel_on_Y_synchronize( @@ -526,28 +1024,8 @@ cdef class PairwiseDistancesReduction: """Update datastructures after executing all the reductions.""" return -cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): - """Compute the argkmin of row vectors of X on the ones of Y. - - For each row vector of X, computes the indices of k first the rows - vectors of Y with the smallest distances. - - PairwiseDistancesArgKmin is typically used to perform - bruteforce k-nearest neighbors queries. - - Parameters - ---------- - datasets_pair: DatasetsPair - The dataset pairs (X, Y) for the reduction. - - chunk_size: int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - k: int, default=1 - The k for the argkmin reduction. - """ +cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" cdef: ITYPE_t k @@ -644,14 +1122,13 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): Notes ----- This public classmethod is responsible for introspecting the arguments - values to dispatch to the private :meth:`PairwiseDistancesArgKmin._compute` - instance method of the most appropriate :class:`PairwiseDistancesArgKmin` - concrete implementation. + values to dispatch to the most appropriate concrete implementation + of :class:`PairwiseDistancesArgKmin64`. All temporarily allocated datastructures necessary for the concrete implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the interface entirely from the + This allows decoupling the interface entirely from the implementation details whilst maintaining RAII. """ # Note (jjerphan): Some design thoughts for future extensions. @@ -669,7 +1146,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # at time to leverage a call to the BLAS GEMM routine as explained # in more details in the docstring. use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesArgKmin( + pda = FastEuclideanPairwiseDistancesArgKmin64( X=X, Y=Y, k=k, use_squared_distances=use_squared_distances, chunk_size=chunk_size, @@ -679,7 +1156,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): else: # Fall back on a generic implementation that handles most scipy # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesArgKmin( + pda = PairwiseDistancesArgKmin64( datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), k=k, chunk_size=chunk_size, @@ -726,7 +1203,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): sizeof(ITYPE_t *) * self.chunks_n_threads ) - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. + # Main heaps which will be returned as results by `PairwiseDistancesArgKmin64.compute`. self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -764,11 +1241,11 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): Y_start + j, ) - @final cdef void _parallel_on_X_init_chunk( self, ITYPE_t thread_num, ITYPE_t X_start, + ITYPE_t X_end, ) nogil: # As this strategy is embarrassingly parallel, we can set each # thread's heaps pointer to the proper position on the main heaps. @@ -819,10 +1296,11 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): heaps_size * sizeof(ITYPE_t) ) - @final cdef void _parallel_on_Y_parallel_init( self, ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, ) nogil: # Initialising heaps (memset can't be used here) for idx in range(self.X_n_samples_chunk * self.k): @@ -899,125 +1377,17 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # Values are returned identically to the way `KNeighborsMixin.kneighbors` # returns values. This is counter-intuitive but this allows not using - # complex adaptations where `PairwiseDistancesArgKmin.compute` is called. + # complex adaptations where `PairwiseDistancesArgKmin64.compute` is called. return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) return np.asarray(self.argkmin_indices) -cdef class GEMMTermComputer: - """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. - - `FastEuclidean*` classes internally compute the squared Euclidean distances between - chunks of vectors X_c and Y_c using using the decomposition: - - - ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - - - This helper class is in charge of wrapping the common logic to compute - the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high - arithmetic intensity. - """ - +cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): + """Fast specialized alternative for PairwiseDistancesArgKmin64 on EuclideanDistance.""" cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y - - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - ITYPE_t dist_middle_terms_chunks_size - - # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM - vector[vector[DTYPE_t]] dist_middle_terms_chunks - - def __init__(self, - DTYPE_t[:, ::1] X, - DTYPE_t[:, ::1] Y, - ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, - ): - self.X = X - self.Y = Y - self.effective_n_threads = effective_n_threads - self.chunks_n_threads = chunks_n_threads - self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size - - self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) - - cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: - self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) - - cdef void _parallel_on_Y_init(self) nogil: - for thread_num in range(self.chunks_n_threads): - self.dist_middle_terms_chunks[thread_num].resize( - self.dist_middle_terms_chunks_size - ) - - cdef DTYPE_t * _compute_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() - - # Careful: LDA, LDB and LDC are given for F-ordered arrays - # in BLAS documentations, for instance: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa - # - # Here, we use their counterpart values to work with C-ordered arrays. - BLAS_Order order = RowMajor - BLAS_Trans ta = NoTrans - BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] - DTYPE_t alpha = - 2. - # Casting for A and B to remove the const is needed because APIs exposed via - # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = &X_c[0, 0] - ITYPE_t lda = X_c.shape[1] - DTYPE_t * B = &Y_c[0, 0] - ITYPE_t ldb = X_c.shape[1] - DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] - - # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) - - return dist_middle_terms + GEMMTermComputer64 gemm_term_computer - -cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): - """Fast specialized variant for PairwiseDistancesArgKmin on EuclideanDistance. - - The full pairwise squared distances matrix is computed as follows: - - ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² - - The middle term gets computed efficiently below using BLAS Level 3 GEMM. - - Notes - ----- - This implementation has a superior arithmetic intensity and hence - better running time when the variant is IO bound, but it can suffer - from numerical instability caused by catastrophic cancellation potentially - introduced by the subtraction in the arithmetic expression above. - """ - - cdef: - GEMMTermComputer gemm_term_computer const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared @@ -1025,7 +1395,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and + return (PairwiseDistancesArgKmin64.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) def __init__( @@ -1045,7 +1415,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case ({self.__class__.__name__}) and will be ignored.", + f"usable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.", UserWarning, stacklevel=3, ) @@ -1059,50 +1429,118 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: - DenseDenseDatasetsPair datasets_pair = self.datasets_pair + DenseDenseDatasetsPair datasets_pair = ( + self.datasets_pair + ) ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - self.gemm_term_computer = GEMMTermComputer( + self.gemm_term_computer = GEMMTermComputer64( datasets_pair.X, datasets_pair.Y, self.effective_n_threads, self.chunks_n_threads, dist_middle_terms_chunks_size, + n_features=datasets_pair.X.shape[1], + chunk_size=self.chunk_size, ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: - self.Y_norm_squared = _sqeuclidean_row_norms(datasets_pair.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms(datasets_pair.X, self.effective_n_threads) + _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: - PairwiseDistancesArgKmin.compute_exact_distances(self) + PairwiseDistancesArgKmin64.compute_exact_distances(self) @final cdef void _parallel_on_X_parallel_init( self, ITYPE_t thread_num, ) nogil: - PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num) + PairwiseDistancesArgKmin64._parallel_on_X_parallel_init(self, thread_num) self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + + @final cdef void _parallel_on_Y_init( self, ) nogil: cdef ITYPE_t thread_num - PairwiseDistancesArgKmin._parallel_on_Y_init(self) + PairwiseDistancesArgKmin64._parallel_on_Y_init(self) self.gemm_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + @final cdef void _compute_and_reduce_distances_on_chunks( self, @@ -1145,30 +1583,8 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) -cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): - """Compute radius-based neighbors for two sets of vectors. - - For each row-vector X[i] of the queries X, find all the indices j of - row-vectors in Y such that: - - dist(X[i], Y[j]) <= radius - - The distance function `dist` depends on the values of the `metric` - and `metric_kwargs` parameters. - - Parameters - ---------- - datasets_pair: DatasetsPair - The dataset pair (X, Y) for the reduction. - - chunk_size: int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - radius: float - The radius defining the neighborhood. - """ +cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" cdef: DTYPE_t radius @@ -1294,17 +1710,15 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): Notes ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private - :meth:`PairwiseDistancesRadiusNeighborhood._compute` instance method of - the most appropriate :class:`PairwiseDistancesRadiusNeighborhood` - concrete implementation. + This public classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate concrete implementation + of :class:`PairwiseDistancesRadiusNeighborhood64`. - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the interface entirely from the - implementation details whilst maintaining RAII. + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. """ # Note (jjerphan): Some design thoughts for future extensions. # This factory comes to handle specialisations for the given arguments. @@ -1321,7 +1735,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # at time to leverage a call to the BLAS GEMM routine as explained # in more details in the docstring. use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesRadiusNeighborhood( + pda = FastEuclideanPairwiseDistancesRadiusNeighborhood64( X=X, Y=Y, radius=radius, use_squared_distances=use_squared_distances, chunk_size=chunk_size, @@ -1332,7 +1746,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): else: # Fall back on a generic implementation that handles most scipy # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesRadiusNeighborhood( + pda = PairwiseDistancesRadiusNeighborhood64( datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), radius=radius, chunk_size=chunk_size, @@ -1423,11 +1837,11 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): return coerce_vectors_to_nd_arrays(self.neigh_indices) - @final cdef void _parallel_on_X_init_chunk( self, ITYPE_t thread_num, ITYPE_t X_start, + ITYPE_t X_end, ) nogil: # As this strategy is embarrassingly parallel, we can set the @@ -1546,26 +1960,11 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): ) -cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRadiusNeighborhood): - """Fast specialized variant for PairwiseDistancesRadiusNeighborhood on EuclideanDistance. - - The full pairwise squared distances matrix is computed as follows: - - ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² - - The middle term gets computed efficiently below using BLAS Level 3 GEMM. - - Notes - ----- - This implementation has a superior arithmetic intensity and hence - better running time when the variant is IO bound, but it can suffer - from numerical instability caused by catastrophic cancellation potentially - introduced by the subtraction in the arithmetic expression above. - numerical precision is needed. - """ +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): + """Fast specialized variant for PairwiseDistancesRadiusNeighborhood on EuclideanDistance.""" cdef: - GEMMTermComputer gemm_term_computer + GEMMTermComputer64 gemm_term_computer const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared @@ -1573,7 +1972,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesRadiusNeighborhood.is_usable_for(X, Y, metric) + return (PairwiseDistancesRadiusNeighborhood64.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) def __init__( @@ -1594,7 +1993,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case ({self.__class__.__name__}) and will be ignored.", + f"usable for this case (FastEuclideanPairwiseDistancesRadiusNeighborhood) and will be ignored.", UserWarning, stacklevel=3, ) @@ -1613,23 +2012,25 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad DenseDenseDatasetsPair datasets_pair = self.datasets_pair ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - self.gemm_term_computer = GEMMTermComputer( + self.gemm_term_computer = GEMMTermComputer64( datasets_pair.X, datasets_pair.Y, self.effective_n_threads, self.chunks_n_threads, dist_middle_terms_chunks_size, + n_features=datasets_pair.X.shape[1], + chunk_size=self.chunk_size, ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: - self.Y_norm_squared = _sqeuclidean_row_norms(datasets_pair.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms(datasets_pair.X, self.effective_n_threads) + _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @@ -1638,27 +2039,85 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad # already considered to be the adapted radius, so we overwrite it. self.r_radius = radius - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistancesRadiusNeighborhood.compute_exact_distances(self) - @final cdef void _parallel_on_X_parallel_init( self, ITYPE_t thread_num, ) nogil: - PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_init(self, thread_num) + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_parallel_init(self, thread_num) self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + @final cdef void _parallel_on_Y_init( self, ) nogil: cdef ITYPE_t thread_num - PairwiseDistancesRadiusNeighborhood._parallel_on_Y_init(self) + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_init(self) self.gemm_term_computer._parallel_on_Y_init() + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesRadiusNeighborhood64.compute_exact_distances(self) + @final cdef void _compute_and_reduce_distances_on_chunks( self, From c772d48c7d7ff527839381be4b2b36a08f3d3166 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 7 Jun 2022 18:14:44 +0200 Subject: [PATCH 02/19] DEBUG Propagate sort_results --- sklearn/metrics/_pairwise_distances_reduction.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index abbcaaf8b5678..b409871b8cffe 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -499,6 +499,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): chunk_size=chunk_size, metric_kwargs=metric_kwargs, strategy=strategy, + sort_results=sort_results, return_distance=return_distance, ) raise ValueError( From 590a8b05141729db01049b6a64d1a39d270d3cd5 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 8 Jun 2022 16:49:14 +0200 Subject: [PATCH 03/19] DOC Reword --- .../metrics/_pairwise_distances_reduction.pyx | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index b409871b8cffe..7c8688dc57eff 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -114,19 +114,22 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( return squared_row_norms ##################### -# Interfaces: -# Those interfaces are meant to be used in the Python code, decoupling the -# actual implementation from the Python code. This allows changing all the -# private implementation while maintaining a contract for the Python callers. +# Dispatcher: # -# Each interface extending the base `PairwiseDistancesReduction` interface must -# implement the :meth:`compute` classmethod. +# Those dispatchers are meant to be used in the Python code, decoupling the +# actual implementations from the Python code. This allows changing all the +# private implementations while maintaining the same contract for the +# Python callers. +# +# Each dispatcher extending the base `PairwiseDistancesReduction` dispatcher +# must implement the :meth:`compute` classmethod. # # Under the hood, such a function must only define the logic to dispatch # at runtime to the correct dtype-specialized `PairwiseDistancesReduction` # implementation based on the dtype of X and of Y. +# -# Base interface +# Base dispatcher cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. @@ -347,7 +350,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): All temporarily allocated datastructures necessary for the concrete implementation are therefore freed when this classmethod returns. - This allows decoupling the interface entirely from the + This allows decoupling the API entirely from the implementation details whilst maintaining RAII. """ if X.dtype == Y.dtype == np.float64: @@ -487,7 +490,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): All temporarily allocated datastructures necessary for the concrete implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the interface entirely from the + This allows entirely decoupling the API entirely from the implementation details whilst maintaining RAII. """ if X.dtype == Y.dtype == np.float64: @@ -508,12 +511,12 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): ) ##################### -# dtype-specific implementations: -# For each dtype, an implementation of `PairwiseDistancesReductions` are generated by Tempita. -# Computations are dispatched to them at runtime via the interfaces defined above. +# dtype-specialized implementations: +# +# For each dtype, an implementation of `PairwiseDistancesReductions` is generated +# by Tempita. Computations are dispatched to them at runtime via the dispatchers +# defined above. Other helpers are also made dtype-specialized. # -# Also, other helper are dtype-specialised. - cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( const DTYPE_t[:, ::1] X, ITYPE_t num_threads, @@ -1129,7 +1132,7 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): All temporarily allocated datastructures necessary for the concrete implementation are therefore freed when this classmethod returns. - This allows decoupling the interface entirely from the + This allows decoupling the API entirely from the implementation details whilst maintaining RAII. """ # Note (jjerphan): Some design thoughts for future extensions. @@ -1385,10 +1388,9 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): - """Fast specialized alternative for PairwiseDistancesArgKmin64 on EuclideanDistance.""" + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" cdef: GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared @@ -1599,7 +1601,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): # Neighbors indices and distances are returned as np.ndarrays of np.ndarrays. # # For this implementation, we want resizable buffers which we will wrap - # into numpy arrays at the end. std::vector comes as a handy interface + # into numpy arrays at the end. std::vector comes as a handy container # for interacting efficiently with resizable buffers. # # Though it is possible to access their buffer address with @@ -1718,7 +1720,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): All temporarily allocated datastructures necessary for the concrete implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the interface entirely from the + This allows entirely decoupling the API entirely from the implementation details whilst maintaining RAII. """ # Note (jjerphan): Some design thoughts for future extensions. @@ -1962,8 +1964,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): - """Fast specialized variant for PairwiseDistancesRadiusNeighborhood on EuclideanDistance.""" - + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" cdef: GEMMTermComputer64 gemm_term_computer const DTYPE_t[::1] X_norm_squared From ec40fea6026f8f956312f76b2b5dd71ae5ef8db3 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 9 Jun 2022 15:12:00 +0200 Subject: [PATCH 04/19] DOC Document dispatchers and implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with some ASCII art. 🎨 Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 448 ++++++------------ 1 file changed, 142 insertions(+), 306 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 7c8688dc57eff..ea1feab6ccf41 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -3,16 +3,87 @@ # # Author: Julien Jerphanion # +# Overview +# -------- # -# The abstractions defined here are used in various algorithms performing -# the same structure of operations on distances between row vectors -# of a datasets pair (X, Y). +# This module provides routines to compute pairwise distances between a set +# of row vectors of X and another set of row vectors of Y and apply a +# reduction on top. # -# Importantly, the core of the computation is chunked to make sure that the pairwise -# distance chunk matrices stay in CPU cache before applying the final reduction step. -# Furthermore, the chunking strategy is also used to leverage OpenMP-based parallelism -# (using Cython prange loops) which gives another multiplicative speed-up in -# favorable cases on many-core machines. +# The reduction takes a matrix of pairwise distances between rows of X and Y +# as input and outputs an aggregate data-structure for each row of X. The +# aggregate values are typically smaller than the number of rows in Y, hence +# the term reduction. +# +# For computational reasons, the reduction are performed on the fly on chunks +# of rows of X and Y so as to keep intermediate data-structures in CPU cache +# and avoid unnecessary round trips of large distance arrays with the RAM +# that would otherwise severely degrade the speed by making the overall +# processing memory-bound. +# +# Finally, the routines follow a generic parallelization template to process +# chunks of data with OpenMP loops (via Cython prange), either on rows of X +# or rows of Y depending on their respective sizes. +# +# +# Dispatcher and implementations +# ------------------------------ +# +# Dispatchers are meant to be used in the Python code. Under the hood, a +# dispatch must only define the logic to choose at runtime to the correct +# dtype-specialized :class:`PairwiseDistancesReduction` implementation based +# on the dtype of X and of Y. +# +# +# High-level diagrams +# ------------------- +# +# Legend: +# +# A ---⊳ B: A inherits from B +# A ---x B: A dispatches on B +# A ---* B: A composes B +# +# +# (base dispatcher) +# PairwiseDistancesReduction +# ∆ +# | +# | +# +-----------------+-----------------+ +# | | +# (dispatcher) (dispatcher) +# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors +# | | +# | | +# | | +# | (64bit implem.) | +# | PairwiseDistancesReduction64 | +# | ∆ | +# | | | +# | | | +# | +-----------------+-----------------+ | +# | | | | +# | | | | +# x | | x +# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 +# | ∆ ∆ | +# | | | | +# x | | | +# FastEuclideanPairwiseDistancesArgKmin64 | | +# * | | +# | | x +# | FastEuclideanPairwiseDistancesRadiusNeighbors64 +# | * +# | | +# +-----------------+-----------------+ +# | +# GEMMTermComputer64 +# +# Finally, the implementation might dispatch to a specialised implementation +# for the Euclidean Distance case using the Generalized Matrix Multiplication +# (see :class:`GEMMTermComputer64`). + cimport numpy as cnp import numpy as np import warnings @@ -114,85 +185,13 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( return squared_row_norms ##################### -# Dispatcher: -# -# Those dispatchers are meant to be used in the Python code, decoupling the -# actual implementations from the Python code. This allows changing all the -# private implementations while maintaining the same contract for the -# Python callers. -# -# Each dispatcher extending the base `PairwiseDistancesReduction` dispatcher -# must implement the :meth:`compute` classmethod. -# -# Under the hood, such a function must only define the logic to dispatch -# at runtime to the correct dtype-specialized `PairwiseDistancesReduction` -# implementation based on the dtype of X and of Y. -# +# Dispatchers -# Base dispatcher cdef class PairwiseDistancesReduction: - """Abstract base class for pairwise distance computation & reduction. - - Subclasses of this class compute pairwise distances between a set of - row vectors of X and another set of row vectors of Y and apply a reduction on top. - The reduction takes a matrix of pairwise distances between rows of X and Y - as input and outputs an aggregate data-structure for each row of X. - The aggregate values are typically smaller than the number of rows in Y, - hence the term reduction. - - For computational reasons, it is interesting to perform the reduction on - the fly on chunks of rows of X and Y so as to keep intermediate - data-structures in CPU cache and avoid unnecessary round trips of large - distance arrays with the RAM that would otherwise severely degrade the - speed by making the overall processing memory-bound. - - The base class provides the generic chunked parallelization template using - OpenMP loops (Cython prange), either on rows of X or rows of Y depending on - their respective sizes. - - The subclasses are specialized for reduction. - - The actual distance computation for a given pair of rows of X and Y are - delegated to format-specific subclasses of the DatasetsPair companion base - class. - - Parameters - ---------- - datasets_pair: DatasetsPair - The pair of dataset to use. - - chunk_size: int, default=None - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + """Abstract base dispatcher for pairwise distance computation & reduction. + + Each dispatcher extending the base :class:`PairwiseDistancesReduction` + dispatcher must implement the :meth:`compute` classmethod. """ @classmethod @@ -211,7 +210,8 @@ cdef class PairwiseDistancesReduction: @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - """Return True if the PairwiseDistancesReduction can be used for the given parameters. + """Return True if the PairwiseDistancesReduction can be used for the + given parameters. Parameters ---------- @@ -245,18 +245,9 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): PairwiseDistancesArgKmin is typically used to perform bruteforce k-nearest neighbors queries. - Parameters - ---------- - datasets_pair: DatasetsPair - The dataset pairs (X, Y) for the reduction. - - chunk_size: int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - k: int, default=1 - The k for the argkmin reduction. + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. """ @classmethod @@ -271,7 +262,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): str strategy=None, bint return_distance=False, ): - """Return the results of the reduction for the given arguments. + """Compute the argkmin reduction. Parameters ---------- @@ -331,27 +322,26 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): Returns ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. + If return_distance=True: + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. Notes ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private dtype-specialized implementation of - :class:`PairwiseDistancesArgKmin`. - - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. - - This allows decoupling the API entirely from the - implementation details whilst maintaining RAII. + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. """ if X.dtype == Y.dtype == np.float64: return PairwiseDistancesArgKmin64.compute( @@ -381,18 +371,9 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): The distance function `dist` depends on the values of the `metric` and `metric_kwargs` parameters. - Parameters - ---------- - datasets_pair: DatasetsPair - The dataset pair (X, Y) for the reduction. - - chunk_size: int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - radius: float - The radius defining the neighborhood. + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. """ @classmethod @@ -483,15 +464,15 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): Notes ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private dtype-specialized implementation of - :class:`PairwiseDistancesRadiusNeighborhood`. + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistancesRadiusNeighborhood`. - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the API entirely from the - implementation details whilst maintaining RAII. + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. """ if X.dtype == Y.dtype == np.float64: return PairwiseDistancesRadiusNeighborhood64.compute( @@ -511,12 +492,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): ) ##################### -# dtype-specialized implementations: -# -# For each dtype, an implementation of `PairwiseDistancesReductions` is generated -# by Tempita. Computations are dispatched to them at runtime via the dispatchers -# defined above. Other helpers are also made dtype-specialized. -# +# dtype-specialized implementations + cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( const DTYPE_t[:, ::1] X, ITYPE_t num_threads, @@ -675,19 +652,19 @@ cdef class GEMMTermComputer64: return dist_middle_terms -cdef class PairwiseDistancesReduction64(PairwiseDistancesReduction): - """64bit implementation of PairwiseDistancesReduction.""" +cdef class PairwiseDistancesReduction64: + """Base 64bit implementation of PairwiseDistancesReduction.""" cdef: readonly DatasetsPair datasets_pair # The number of threads that can be used is stored in effective_n_threads. # - # The number of threads to use in the parallelisation strategy + # The number of threads to use in the parallelization strategy # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: - # for small datasets, less threads might be needed to loop over pair of chunks. + # for small datasets, fewer threads might be needed to loop over pair of chunks. # - # Hence the number of threads that _will_ be used for looping over chunks + # Hence, the number of threads that _will_ be used for looping over chunks # is stored in chunks_n_threads, allowing solely using what we need. # # Thus, an invariant is: @@ -921,7 +898,7 @@ cdef class PairwiseDistancesReduction64(PairwiseDistancesReduction): """Compute the pairwise distances on two chunks of X and Y and reduce them. This is THE core computational method of PairwiseDistanceReductions64. - This must be implemented in subclasses agnostically from the parallelisation + This must be implemented in subclasses agnostically from the parallelization strategies. """ return @@ -1053,87 +1030,18 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): str strategy=None, bint return_distance=False, ): - """Return the results of the reduction for the given arguments. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - k : int - The k for the argkmin reduction. + """Compute the argkmin reduction. - metric : str, default='euclidean' - The distance metric to use for argkmin. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin64`. - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its - argkmin if set to True. - - Returns - ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - Notes - ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate concrete implementation - of :class:`PairwiseDistancesArgKmin64`. - - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. - - This allows decoupling the API entirely from the - implementation details whilst maintaining RAII. + No instance should directly be created outside of this class method. """ # Note (jjerphan): Some design thoughts for future extensions. # This factory comes to handle specialisations for the given arguments. @@ -1638,90 +1546,18 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): bint return_distance=False, bint sort_results=False, ): - """Return the results of the reduction for the given arguments. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - radius : float - The radius defining the neighborhood. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: + """Compute the radius-neighbors reduction. - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its neighbors if set to True. - - sort_results : boolean, default=False - Sort results with respect to distances between each X vector and its - neighbors if set to True. - - Returns - ------- - If return_distance=False: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - If return_distance=True: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - neighbors_distances : ndarray of n_samples_X ndarray - Distances to the neighbors for each vector in X. - - Notes - ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate concrete implementation - of :class:`PairwiseDistancesRadiusNeighborhood64`. + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesRadiusNeighborhood64`. - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. - This allows entirely decoupling the API entirely from the - implementation details whilst maintaining RAII. + No instance should directly be created outside of this class method. """ # Note (jjerphan): Some design thoughts for future extensions. # This factory comes to handle specialisations for the given arguments. From cd49d8a17720c6f0206dfd357e46608c92a62fab Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 10 Jun 2022 11:10:27 +0200 Subject: [PATCH 05/19] MAINT Simply use python class for dispatchers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- sklearn/metrics/_pairwise_distances_reduction.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index ea1feab6ccf41..72950970c01e1 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -187,7 +187,7 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( ##################### # Dispatchers -cdef class PairwiseDistancesReduction: +class PairwiseDistancesReduction: """Abstract base dispatcher for pairwise distance computation & reduction. Each dispatcher extending the base :class:`PairwiseDistancesReduction` @@ -236,7 +236,7 @@ cdef class PairwiseDistancesReduction: metric in cls.valid_metrics()) -cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): +class PairwiseDistancesArgKmin(PairwiseDistancesReduction): """Compute the argkmin of row vectors of X on the ones of Y. For each row vector of X, computes the indices of k first the rows @@ -360,7 +360,7 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): ) -cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): +class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): """Compute radius-based neighbors for two sets of vectors. For each row-vector X[i] of the queries X, find all the indices j of From c7dc9871874d6fb7d01563962914c4a8acc7b838 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 10 Jun 2022 16:18:58 +0200 Subject: [PATCH 06/19] DOC Improve comments Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 72950970c01e1..702a6d1c8f712 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -8,7 +8,8 @@ # # This module provides routines to compute pairwise distances between a set # of row vectors of X and another set of row vectors of Y and apply a -# reduction on top. +# reduction on top. The canonical example is the brute-force computation +# of the top k nearest neighbors by leveraging the arg-k-min reduction. # # The reduction takes a matrix of pairwise distances between rows of X and Y # as input and outputs an aggregate data-structure for each row of X. The @@ -26,17 +27,17 @@ # or rows of Y depending on their respective sizes. # # -# Dispatcher and implementations -# ------------------------------ +# Dispatching to specialized implementations +# ------------------------------------------ # # Dispatchers are meant to be used in the Python code. Under the hood, a -# dispatch must only define the logic to choose at runtime to the correct +# dispatcher must only define the logic to choose at runtime to the correct # dtype-specialized :class:`PairwiseDistancesReduction` implementation based # on the dtype of X and of Y. # # -# High-level diagrams -# ------------------- +# High-level diagram +# ------------------ # # Legend: # @@ -80,9 +81,15 @@ # | # GEMMTermComputer64 # -# Finally, the implementation might dispatch to a specialised implementation -# for the Euclidean Distance case using the Generalized Matrix Multiplication -# (see :class:`GEMMTermComputer64`). +# For instance :class:`PairwiseDistancesArgKmin`, dispatches to +# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays +# with a float64 dtype. +# +# In addition, if the metric parameter is set to "sqeuclidean", +# :class:`PairwiseDistancesArgKmin64` further dispatches to a subclass +# specialized to optimally handle the Euclidean distance case using the +# Generalized Matrix Multiplication (see the docstring of +# :class:`GEMMTermComputer64` for details). cimport numpy as cnp import numpy as np @@ -779,7 +786,7 @@ cdef class PairwiseDistancesReduction64: X_end = X_start + self.X_n_samples_chunk # Reinitializing thread datastructures for the new X chunk - # Eventually upcast X[X_start:X_end] to 64bit + # If necessary, upcast X[X_start:X_end] to 64bit self._parallel_on_X_init_chunk(thread_num, X_start, X_end) for Y_chunk_idx in range(self.Y_n_chunks): @@ -789,7 +796,7 @@ cdef class PairwiseDistancesReduction64: else: Y_end = Y_start + self.Y_n_samples_chunk - # Eventually upcast Y[Y_start:Y_end] to 64bit + # If necessary, upcast Y[Y_start:Y_end] to 64bit self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( X_start, X_end, Y_start, Y_end, @@ -848,7 +855,7 @@ cdef class PairwiseDistancesReduction64: thread_num = _openmp_thread_num() # Initializing datastructures used in this thread - # Eventually upcast X[X_start:X_end] to 64bit + # If necessary, upcast X[X_start:X_end] to 64bit self._parallel_on_Y_parallel_init(thread_num, X_start, X_end) for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): @@ -858,7 +865,7 @@ cdef class PairwiseDistancesReduction64: else: Y_end = Y_start + self.Y_n_samples_chunk - # Eventually upcast Y[Y_start:Y_end] to 64bit + # If necessary, upcast Y[Y_start:Y_end] to 64bit self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( X_start, X_end, Y_start, Y_end, From e6f9c9aa5cf6f6791a2d4e352e0aa8f1e3d7f228 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 14 Jun 2022 09:57:15 +0200 Subject: [PATCH 07/19] Apply review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger Co-authored-by: Thomas J. Fan Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 175 +++++++++--------- .../test_pairwise_distances_reduction.py | 4 +- 2 files changed, 87 insertions(+), 92 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 702a6d1c8f712..7a8b763541721 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -43,53 +43,48 @@ # # A ---⊳ B: A inherits from B # A ---x B: A dispatches on B -# A ---* B: A composes B # # -# (base dispatcher) -# PairwiseDistancesReduction -# ∆ -# | -# | -# +-----------------+-----------------+ -# | | -# (dispatcher) (dispatcher) -# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors -# | | -# | | -# | | -# | (64bit implem.) | -# | PairwiseDistancesReduction64 | -# | ∆ | -# | | | -# | | | -# | +-----------------+-----------------+ | -# | | | | -# | | | | -# x | | x -# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 -# | ∆ ∆ | -# | | | | -# x | | | -# FastEuclideanPairwiseDistancesArgKmin64 | | -# * | | -# | | x -# | FastEuclideanPairwiseDistancesRadiusNeighbors64 -# | * -# | | -# +-----------------+-----------------+ -# | -# GEMMTermComputer64 +# (base dispatcher) +# PairwiseDistancesReduction +# ∆ +# | +# | +# +-----------------+-----------------+ +# | | +# (dispatcher) (dispatcher) +# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors +# | | +# | | +# | | +# | (64bit implem.) | +# | PairwiseDistancesReduction64 | +# | ∆ | +# | | | +# | | | +# | +-----------------+-----------------+ | +# | | | | +# | | | | +# x | | x +# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 +# | ∆ ∆ | +# | | | | +# x | | | +# FastEuclideanPairwiseDistancesArgKmin64 | | +# | | +# | x +# FastEuclideanPairwiseDistancesRadiusNeighbors64 # -# For instance :class:`PairwiseDistancesArgKmin`, dispatches to -# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays -# with a float64 dtype. +# For instance :class:`PairwiseDistancesArgKmin`, dispatches to +# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays +# with a float64 dtype. # -# In addition, if the metric parameter is set to "sqeuclidean", -# :class:`PairwiseDistancesArgKmin64` further dispatches to a subclass -# specialized to optimally handle the Euclidean distance case using the -# Generalized Matrix Multiplication (see the docstring of -# :class:`GEMMTermComputer64` for details). +# In addition, if the metric parameter is set to "euclidean" or "sqeuclidean", +# :class:`PairwiseDistancesArgKmin64` further dispatches to +# :class:`FastEuclideanPairwiseDistancesArgKmin64` a specialized subclass +# to optimally handle the Euclidean distance case using the Generalized Matrix +# Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). +from abc import abstractmethod cimport numpy as cnp import numpy as np @@ -165,32 +160,6 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( return nd_arrays_of_nd_arrays -##################### - -cpdef DTYPE_t[::1] _sqeuclidean_row_norms( - const DTYPE_t[:, ::1] X, - ITYPE_t num_threads, -): - """Compute the squared euclidean norm of the rows of X in parallel. - - This is faster than using np.einsum("ij, ij->i") even when using a single thread. - """ - cdef: - # Casting for X to remove the const qualifier is needed because APIs - # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' - # const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * X_ptr = &X[0, 0] - ITYPE_t idx = 0 - ITYPE_t n = X.shape[0] - ITYPE_t d = X.shape[1] - DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) - - for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): - squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) - - return squared_row_norms - ##################### # Dispatchers @@ -242,6 +211,32 @@ class PairwiseDistancesReduction: not issparse(X) and not issparse(Y) and dtypes_validity and metric in cls.valid_metrics()) + @classmethod + @abstractmethod + def compute( + cls, + X, + Y, + **kwargs, + ): + """Compute the reduction. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + **kwargs : additional parameters for the reduction + + Notes + ----- + This method is an abstract class method: it has to be implemented + for all + + """ class PairwiseDistancesArgKmin(PairwiseDistancesReduction): """Compute the argkmin of row vectors of X on the ones of Y. @@ -262,12 +257,12 @@ class PairwiseDistancesArgKmin(PairwiseDistancesReduction): cls, X, Y, - ITYPE_t k, - str metric="euclidean", + k, + metric="euclidean", chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, + metric_kwargs=None, + strategy=None, + return_distance=False, ): """Compute the argkmin reduction. @@ -350,6 +345,11 @@ class PairwiseDistancesArgKmin(PairwiseDistancesReduction): for the concrete implementation are therefore freed when this classmethod returns. """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. if X.dtype == Y.dtype == np.float64: return PairwiseDistancesArgKmin64.compute( X=X, @@ -388,13 +388,13 @@ class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): cls, X, Y, - DTYPE_t radius, - str metric="euclidean", + radius, + metric="euclidean", chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - bint sort_results=False, + metric_kwargs=None, + strategy=None, + return_distance=False, + sort_results=False, ): """Return the results of the reduction for the given arguments. @@ -481,6 +481,11 @@ class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): This allows entirely decoupling the API entirely from the implementation details whilst maintaining RAII. """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. if X.dtype == Y.dtype == np.float64: return PairwiseDistancesRadiusNeighborhood64.compute( X=X, @@ -1050,11 +1055,6 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): No instance should directly be created outside of this class method. """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. if ( metric in ("euclidean", "sqeuclidean") and not issparse(X) @@ -1566,11 +1566,6 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): No instance should directly be created outside of this class method. """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. if ( metric in ("euclidean", "sqeuclidean") and not issparse(X) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index fa475134c7a9f..b47407f3754ee 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -12,7 +12,7 @@ PairwiseDistancesReduction, PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood, - _sqeuclidean_row_norms, + _sqeuclidean_row_norms64, ) from sklearn.metrics import euclidean_distances @@ -967,6 +967,6 @@ def test_sqeuclidean_row_norms( X = rng.rand(n_samples, n_features).astype(dtype) * spread sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 - sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) + sq_row_norm = np.asarray(_sqeuclidean_row_norms64(X, num_threads=num_threads)) assert_allclose(sq_row_norm_reference, sq_row_norm) From 3a8eb53ad232f3c0abc1f24d4974ffc31418d87c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 9 Jun 2022 15:59:36 +0200 Subject: [PATCH 08/19] MAINT Create private _pairwise_distances_reductions submodule --- sklearn/metrics/_dist_metrics.pxd.tp | 20 - sklearn/metrics/_dist_metrics.pyx.tp | 161 -- .../metrics/_pairwise_distances_reduction.pyx | 1993 ----------------- .../_pairwise_distances_reduction/__init__.py | 102 + .../_argkmin.pxd | 33 + .../_argkmin.pyx | 625 ++++++ .../_pairwise_distances_reduction/_base.pxd | 128 ++ .../_pairwise_distances_reduction/_base.pyx | 424 ++++ .../_datasets_pair.pxd | 21 + .../_datasets_pair.pyx | 164 ++ .../_gemm_term_computer.pxd | 62 + .../_gemm_term_computer.pyx | 135 ++ .../_radius_neighborhood.pxd | 89 + .../_radius_neighborhood.pyx | 638 ++++++ .../_pairwise_distances_reduction/setup.py | 40 + .../tests/__init__.py | 0 .../test_pairwise_distances_reduction.py | 22 +- sklearn/metrics/setup.py | 10 +- sklearn/neighbors/tests/test_neighbors.py | 4 +- sklearn/utils/_testing.py | 19 + 20 files changed, 2484 insertions(+), 2206 deletions(-) delete mode 100644 sklearn/metrics/_pairwise_distances_reduction.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/__init__.py create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_base.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_base.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx create mode 100644 sklearn/metrics/_pairwise_distances_reduction/setup.py create mode 100644 sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py rename sklearn/metrics/{ => _pairwise_distances_reduction}/tests/test_pairwise_distances_reduction.py (98%) diff --git a/sklearn/metrics/_dist_metrics.pxd.tp b/sklearn/metrics/_dist_metrics.pxd.tp index 32ba546672c6e..ef23f2af50ffb 100644 --- a/sklearn/metrics/_dist_metrics.pxd.tp +++ b/sklearn/metrics/_dist_metrics.pxd.tp @@ -101,23 +101,3 @@ cdef class DistanceMetric{{name_suffix}}: cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1 {{endfor}} - -###################################################################### -# DatasetsPair base class -cdef class DatasetsPair: - cdef DistanceMetric distance_metric - - cdef ITYPE_t n_samples_X(self) nogil - - cdef ITYPE_t n_samples_Y(self) nogil - - cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil - - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil - - -cdef class DenseDenseDatasetsPair(DatasetsPair): - cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y - ITYPE_t d diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 5986fa939b45d..47bd1dcbab519 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -32,7 +32,6 @@ implementation_specific_values = [ import numpy as np cimport numpy as cnp -from cython cimport final cnp.import_array() # required in order to use C-API @@ -1171,163 +1170,3 @@ cdef class PyFuncDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): "vectors and return a float.") {{endfor}} - -###################################################################### -# Datasets Pair Classes -cdef class DatasetsPair: - """Abstract class which wraps a pair of datasets (X, Y). - - This class allows computing distances between a single pair of rows of - of X and Y at a time given the pair of their indices (i, j). This class is - specialized for each metric thanks to the :func:`get_for` factory classmethod. - - The handling of parallelization over chunks to compute the distances - and aggregation for several rows at a time is done in dedicated - subclasses of PairwiseDistancesReduction that in-turn rely on - subclasses of DatasetsPair for each pair of rows in the data. The goal - is to make it possible to decouple the generic parallelization and - aggregation logic from metric-specific computation as much as - possible. - - X and Y can be stored as C-contiguous np.ndarrays or CSR matrices - in subclasses. - - This class avoids the overhead of dispatching distance computations - to :class:`sklearn.metrics.DistanceMetric` based on the physical - representation of the vectors (sparse vs. dense). It makes use of - cython.final to remove the overhead of dispatching method calls. - - Parameters - ---------- - distance_metric: DistanceMetric - The distance metric responsible for computing distances - between two vectors of (X, Y). - """ - - @classmethod - def get_for( - cls, - X, - Y, - str metric="euclidean", - dict metric_kwargs=None, - ) -> DatasetsPair: - """Return the DatasetsPair implementation for the given arguments. - - Parameters - ---------- - X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) - Input data. - If provided as a ndarray, it must be C-contiguous. - If provided as a sparse matrix, it must be in CSR format. - - Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) - Input data. - If provided as a ndarray, it must be C-contiguous. - If provided as a sparse matrix, it must be in CSR format. - - metric : str, default='euclidean' - The distance metric to compute between rows of X and Y. - The default metric is a fast implementation of the Euclidean - metric. For a list of available metrics, see the documentation - of :class:`~sklearn.metrics.DistanceMetric`. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - Returns - ------- - datasets_pair: DatasetsPair - The suited DatasetsPair implementation. - """ - cdef: - DistanceMetric distance_metric = DistanceMetric.get_metric( - metric, - **(metric_kwargs or {}) - ) - - if not(X.dtype == Y.dtype == np.float64): - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - - # Metric-specific checks that do not replace nor duplicate `check_array`. - distance_metric._validate_data(X) - distance_metric._validate_data(Y) - - # TODO: dispatch to other dataset pairs for sparse support once available: - if issparse(X) or issparse(Y): - raise ValueError("Only dense datasets are supported for X and Y.") - - return DenseDenseDatasetsPair(X, Y, distance_metric) - - def __init__(self, DistanceMetric distance_metric): - self.distance_metric = distance_metric - - cdef ITYPE_t n_samples_X(self) nogil: - """Number of samples in X.""" - # This is a abstract method. - # This _must_ always be overwritten in subclasses. - # TODO: add "with gil: raise" here when supporting Cython 3.0 - return -999 - - cdef ITYPE_t n_samples_Y(self) nogil: - """Number of samples in Y.""" - # This is a abstract method. - # This _must_ always be overwritten in subclasses. - # TODO: add "with gil: raise" here when supporting Cython 3.0 - return -999 - - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: - return self.dist(i, j) - - cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: - # This is a abstract method. - # This _must_ always be overwritten in subclasses. - # TODO: add "with gil: raise" here when supporting Cython 3.0 - return -1 - -@final -cdef class DenseDenseDatasetsPair(DatasetsPair): - """Compute distances between row vectors of two arrays. - - Parameters - ---------- - X: ndarray of shape (n_samples_X, n_features) - Rows represent vectors. Must be C-contiguous. - - Y: ndarray of shape (n_samples_Y, n_features) - Rows represent vectors. Must be C-contiguous. - - distance_metric: DistanceMetric - The distance metric responsible for computing distances - between two row vectors of (X, Y). - """ - - def __init__(self, X, Y, DistanceMetric distance_metric): - super().__init__(distance_metric) - # Arrays have already been checked - self.X = X - self.Y = Y - self.d = X.shape[1] - - @final - cdef ITYPE_t n_samples_X(self) nogil: - return self.X.shape[0] - - @final - cdef ITYPE_t n_samples_Y(self) nogil: - return self.Y.shape[0] - - @final - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: - return self.distance_metric.rdist(&self.X[i, 0], - &self.Y[j, 0], - self.d) - - @final - cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: - return self.distance_metric.dist(&self.X[i, 0], - &self.Y[j, 0], - self.d) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx deleted file mode 100644 index 7a8b763541721..0000000000000 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ /dev/null @@ -1,1993 +0,0 @@ -# Pairwise Distances Reductions -# ============================= -# -# Author: Julien Jerphanion -# -# Overview -# -------- -# -# This module provides routines to compute pairwise distances between a set -# of row vectors of X and another set of row vectors of Y and apply a -# reduction on top. The canonical example is the brute-force computation -# of the top k nearest neighbors by leveraging the arg-k-min reduction. -# -# The reduction takes a matrix of pairwise distances between rows of X and Y -# as input and outputs an aggregate data-structure for each row of X. The -# aggregate values are typically smaller than the number of rows in Y, hence -# the term reduction. -# -# For computational reasons, the reduction are performed on the fly on chunks -# of rows of X and Y so as to keep intermediate data-structures in CPU cache -# and avoid unnecessary round trips of large distance arrays with the RAM -# that would otherwise severely degrade the speed by making the overall -# processing memory-bound. -# -# Finally, the routines follow a generic parallelization template to process -# chunks of data with OpenMP loops (via Cython prange), either on rows of X -# or rows of Y depending on their respective sizes. -# -# -# Dispatching to specialized implementations -# ------------------------------------------ -# -# Dispatchers are meant to be used in the Python code. Under the hood, a -# dispatcher must only define the logic to choose at runtime to the correct -# dtype-specialized :class:`PairwiseDistancesReduction` implementation based -# on the dtype of X and of Y. -# -# -# High-level diagram -# ------------------ -# -# Legend: -# -# A ---⊳ B: A inherits from B -# A ---x B: A dispatches on B -# -# -# (base dispatcher) -# PairwiseDistancesReduction -# ∆ -# | -# | -# +-----------------+-----------------+ -# | | -# (dispatcher) (dispatcher) -# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors -# | | -# | | -# | | -# | (64bit implem.) | -# | PairwiseDistancesReduction64 | -# | ∆ | -# | | | -# | | | -# | +-----------------+-----------------+ | -# | | | | -# | | | | -# x | | x -# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 -# | ∆ ∆ | -# | | | | -# x | | | -# FastEuclideanPairwiseDistancesArgKmin64 | | -# | | -# | x -# FastEuclideanPairwiseDistancesRadiusNeighbors64 -# -# For instance :class:`PairwiseDistancesArgKmin`, dispatches to -# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays -# with a float64 dtype. -# -# In addition, if the metric parameter is set to "euclidean" or "sqeuclidean", -# :class:`PairwiseDistancesArgKmin64` further dispatches to -# :class:`FastEuclideanPairwiseDistancesArgKmin64` a specialized subclass -# to optimally handle the Euclidean distance case using the Generalized Matrix -# Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). -from abc import abstractmethod - -cimport numpy as cnp -import numpy as np -import warnings - -from .. import get_config -from libc.stdlib cimport free, malloc -from libc.float cimport DBL_MAX -from libcpp.memory cimport shared_ptr, make_shared -from libcpp.vector cimport vector -from cython cimport final -from cython.operator cimport dereference as deref -from cython.parallel cimport parallel, prange - -from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair -from ..utils._cython_blas cimport ( - BLAS_Order, - BLAS_Trans, - ColMajor, - NoTrans, - RowMajor, - Trans, - _dot, - _gemm, -) -from ..utils._heap cimport heap_push -from ..utils._sorting cimport simultaneous_sort -from ..utils._openmp_helpers cimport _openmp_thread_num -from ..utils._typedefs cimport ITYPE_t, DTYPE_t -from ..utils._vector_sentinel cimport vector_to_nd_array - -from numbers import Integral, Real -from typing import List -from scipy.sparse import issparse -from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING -from ..utils import check_scalar, _in_unstable_openblas_configuration -from ..utils.fixes import threadpool_limits -from ..utils._openmp_helpers import _openmp_effective_n_threads -from ..utils._typedefs import ITYPE, DTYPE - -cnp.import_array() - -# TODO: change for `libcpp.algorithm.move` once Cython 3 is used -# Introduction in Cython: -# https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L47 #noqa -cdef extern from "" namespace "std" nogil: - OutputIt move[InputIt, OutputIt](InputIt first, InputIt last, OutputIt d_first) except + #noqa - -###################### -## std::vector to np.ndarray coercion -# As type covariance is not supported for C++ containers via Cython, -# we need to redefine fused types. -ctypedef fused vector_DITYPE_t: - vector[ITYPE_t] - vector[DTYPE_t] - - -ctypedef fused vector_vector_DITYPE_t: - vector[vector[ITYPE_t]] - vector[vector[DTYPE_t]] - - -cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( - shared_ptr[vector_vector_DITYPE_t] vecs -): - """Coerce a std::vector of std::vector to a ndarray of ndarray.""" - cdef: - ITYPE_t n = deref(vecs).size() - cnp.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, dtype=np.ndarray) - - for i in range(n): - nd_arrays_of_nd_arrays[i] = vector_to_nd_array(&(deref(vecs)[i])) - - return nd_arrays_of_nd_arrays - -##################### -# Dispatchers - -class PairwiseDistancesReduction: - """Abstract base dispatcher for pairwise distance computation & reduction. - - Each dispatcher extending the base :class:`PairwiseDistancesReduction` - dispatcher must implement the :meth:`compute` classmethod. - """ - - @classmethod - def valid_metrics(cls) -> List[str]: - excluded = { - "pyfunc", # is relatively slow because we need to coerce data as np arrays - "mahalanobis", # is numerically unstable - # TODO: In order to support discrete distance metrics, we need to have a - # stable simultaneous sort which preserves the order of the input. - # The best might be using std::stable_sort and a Comparator taking an - # Arrays of Structures instead of Structure of Arrays (currently used). - "hamming", - *BOOL_METRICS, - } - return sorted(set(METRIC_MAPPING.keys()) - excluded) - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - """Return True if the PairwiseDistancesReduction can be used for the - given parameters. - - Parameters - ---------- - X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) - Input data. - - Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) - Input data. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - Returns - ------- - True if the PairwiseDistancesReduction can be used, else False. - """ - dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 - return (get_config().get("enable_cython_pairwise_dist", True) and - not issparse(X) and not issparse(Y) and dtypes_validity and - metric in cls.valid_metrics()) - - @classmethod - @abstractmethod - def compute( - cls, - X, - Y, - **kwargs, - ): - """Compute the reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - **kwargs : additional parameters for the reduction - - Notes - ----- - This method is an abstract class method: it has to be implemented - for all - - """ - -class PairwiseDistancesArgKmin(PairwiseDistancesReduction): - """Compute the argkmin of row vectors of X on the ones of Y. - - For each row vector of X, computes the indices of k first the rows - vectors of Y with the smallest distances. - - PairwiseDistancesArgKmin is typically used to perform - bruteforce k-nearest neighbors queries. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - k, - metric="euclidean", - chunk_size=None, - metric_kwargs=None, - strategy=None, - return_distance=False, - ): - """Compute the argkmin reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - k : int - The k for the argkmin reduction. - - metric : str, default='euclidean' - The distance metric to use for argkmin. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its - argkmin if set to True. - - Returns - ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - Notes - ----- - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesArgKmin64.compute( - X=X, - Y=Y, - k=k, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - - -class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): - """Compute radius-based neighbors for two sets of vectors. - - For each row-vector X[i] of the queries X, find all the indices j of - row-vectors in Y such that: - - dist(X[i], Y[j]) <= radius - - The distance function `dist` depends on the values of the `metric` - and `metric_kwargs` parameters. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - radius, - metric="euclidean", - chunk_size=None, - metric_kwargs=None, - strategy=None, - return_distance=False, - sort_results=False, - ): - """Return the results of the reduction for the given arguments. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - radius : float - The radius defining the neighborhood. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its neighbors if set to True. - - sort_results : boolean, default=False - Sort results with respect to distances between each X vector and its - neighbors if set to True. - - Returns - ------- - If return_distance=False: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - If return_distance=True: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - neighbors_distances : ndarray of n_samples_X ndarray - Distances to the neighbors for each vector in X. - - Notes - ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private dtype-specialized implementation of - :class:`PairwiseDistancesRadiusNeighborhood`. - - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. - - This allows entirely decoupling the API entirely from the - implementation details whilst maintaining RAII. - """ - # Note (jjerphan): Some design thoughts for future extensions. - # This factory comes to handle specialisations for the given arguments. - # For future work, this might can be an entrypoint to specialise operations - # for various backend and/or hardware and/or datatypes, and/or fused - # {sparse, dense}-datasetspair etc. - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesRadiusNeighborhood64.compute( - X=X, - Y=Y, - radius=radius, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - -##################### -# dtype-specialized implementations - -cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( - const DTYPE_t[:, ::1] X, - ITYPE_t num_threads, -): - """Compute the squared euclidean norm of the rows of X in parallel. - - This is faster than using np.einsum("ij, ij->i") even when using a single thread. - """ - cdef: - # Casting for X to remove the const qualifier is needed because APIs - # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' - # const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * X_ptr = &X[0, 0] - ITYPE_t i = 0 - ITYPE_t n = X.shape[0] - ITYPE_t d = X.shape[1] - DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) - - for i in prange(n, schedule='static', nogil=True, num_threads=num_threads): - squared_row_norms[i] = _dot(d, X_ptr + i * d, 1, X_ptr + i * d, 1) - - return squared_row_norms - -cdef class GEMMTermComputer64: - """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. - - `FastEuclidean*` classes internally compute the squared Euclidean distances between - chunks of vectors X_c and Y_c using the following decomposition: - - - ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - - - This helper class is in charge of wrapping the common logic to compute - the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high - arithmetic intensity. - """ - cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y - - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - ITYPE_t dist_middle_terms_chunks_size - ITYPE_t n_features - ITYPE_t chunk_size - - # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM - vector[vector[DTYPE_t]] dist_middle_terms_chunks - - def __init__(self, - DTYPE_t[:, ::1] X, - DTYPE_t[:, ::1] Y, - ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, - ITYPE_t n_features, - ITYPE_t chunk_size, - ): - self.X = X - self.Y = Y - self.effective_n_threads = effective_n_threads - self.chunks_n_threads = chunks_n_threads - self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size - self.n_features = n_features - self.chunk_size = chunk_size - - self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) - - - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - return - - cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: - self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - return - - cdef void _parallel_on_Y_init(self) nogil: - for thread_num in range(self.chunks_n_threads): - self.dist_middle_terms_chunks[thread_num].resize( - self.dist_middle_terms_chunks_size - ) - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - return - - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num - ) nogil: - return - - cdef DTYPE_t * _compute_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() - - # Careful: LDA, LDB and LDC are given for F-ordered arrays - # in BLAS documentations, for instance: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa - # - # Here, we use their counterpart values to work with C-ordered arrays. - BLAS_Order order = RowMajor - BLAS_Trans ta = NoTrans - BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] - DTYPE_t alpha = - 2. - # Casting for A and B to remove the const is needed because APIs exposed via - # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = &X_c[0, 0] - DTYPE_t * B = &Y_c[0, 0] - ITYPE_t lda = X_c.shape[1] - ITYPE_t ldb = X_c.shape[1] - DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] - - # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) - - return dist_middle_terms - -cdef class PairwiseDistancesReduction64: - """Base 64bit implementation of PairwiseDistancesReduction.""" - - cdef: - readonly DatasetsPair datasets_pair - - # The number of threads that can be used is stored in effective_n_threads. - # - # The number of threads to use in the parallelization strategy - # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: - # for small datasets, fewer threads might be needed to loop over pair of chunks. - # - # Hence, the number of threads that _will_ be used for looping over chunks - # is stored in chunks_n_threads, allowing solely using what we need. - # - # Thus, an invariant is: - # - # chunks_n_threads <= effective_n_threads - # - ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - - ITYPE_t n_samples_chunk, chunk_size - - ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk - ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk - - bint execute_in_parallel_on_Y - - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - strategy=None, - ): - cdef: - ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks - - if chunk_size is None: - chunk_size = get_config().get("pairwise_dist_chunk_size", 256) - - self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) - - self.effective_n_threads = _openmp_effective_n_threads() - - self.datasets_pair = datasets_pair - - self.n_samples_X = datasets_pair.n_samples_X() - self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) - X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk - X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk - self.X_n_chunks = X_n_full_chunks + (X_n_samples_remainder != 0) - - if X_n_samples_remainder != 0: - self.X_n_samples_last_chunk = X_n_samples_remainder - else: - self.X_n_samples_last_chunk = self.X_n_samples_chunk - - self.n_samples_Y = datasets_pair.n_samples_Y() - self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) - Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk - Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk - self.Y_n_chunks = Y_n_full_chunks + (Y_n_samples_remainder != 0) - - if Y_n_samples_remainder != 0: - self.Y_n_samples_last_chunk = Y_n_samples_remainder - else: - self.Y_n_samples_last_chunk = self.Y_n_samples_chunk - - if strategy is None: - strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') - - if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): - raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " - f"or 'auto', but currently strategy='{self.strategy}'.") - - if strategy == 'auto': - # This is a simple heuristic whose constant for the - # comparison has been chosen based on experiments. - if 4 * self.chunk_size * self.effective_n_threads < self.n_samples_X: - strategy = 'parallel_on_X' - else: - strategy = 'parallel_on_Y' - - self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" - - # Not using less, not using more. - self.chunks_n_threads = min( - self.Y_n_chunks if self.execute_in_parallel_on_Y else self.X_n_chunks, - self.effective_n_threads, - ) - - @final - cdef void _parallel_on_X(self) nogil: - """Compute the pairwise distances of each row vector of X on Y - by parallelizing computation on the outer loop on chunks of X - and reduce them. - - This strategy dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - Private datastructures are modified internally by threads. - - Private template methods can be implemented on subclasses to - interact with those datastructures at various stages. - """ - cdef: - ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx - ITYPE_t thread_num - - with nogil, parallel(num_threads=self.chunks_n_threads): - thread_num = _openmp_thread_num() - - # Allocating thread datastructures - self._parallel_on_X_parallel_init(thread_num) - - for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): - X_start = X_chunk_idx * self.X_n_samples_chunk - if X_chunk_idx == self.X_n_chunks - 1: - X_end = X_start + self.X_n_samples_last_chunk - else: - X_end = X_start + self.X_n_samples_chunk - - # Reinitializing thread datastructures for the new X chunk - # If necessary, upcast X[X_start:X_end] to 64bit - self._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - for Y_chunk_idx in range(self.Y_n_chunks): - Y_start = Y_chunk_idx * self.Y_n_samples_chunk - if Y_chunk_idx == self.Y_n_chunks - 1: - Y_end = Y_start + self.Y_n_samples_last_chunk - else: - Y_end = Y_start + self.Y_n_samples_chunk - - # If necessary, upcast Y[Y_start:Y_end] to 64bit - self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - self._compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - # Adjusting thread datastructures on the full pass on Y - self._parallel_on_X_prange_iter_finalize(thread_num, X_start, X_end) - - # end: for X_chunk_idx - - # Deallocating thread datastructures - self._parallel_on_X_parallel_finalize(thread_num) - - # end: with nogil, parallel - return - - @final - cdef void _parallel_on_Y(self) nogil: - """Compute the pairwise distances of each row vector of X on Y - by parallelizing computation on the inner loop on chunks of Y - and reduce them. - - This strategy dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - Private datastructures are modified internally by threads. - - Private template methods can be implemented on subclasses to - interact with those datastructures at various stages. - """ - cdef: - ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx - ITYPE_t thread_num - - # Allocating datastructures shared by all threads - self._parallel_on_Y_init() - - for X_chunk_idx in range(self.X_n_chunks): - X_start = X_chunk_idx * self.X_n_samples_chunk - if X_chunk_idx == self.X_n_chunks - 1: - X_end = X_start + self.X_n_samples_last_chunk - else: - X_end = X_start + self.X_n_samples_chunk - - with nogil, parallel(num_threads=self.chunks_n_threads): - thread_num = _openmp_thread_num() - - # Initializing datastructures used in this thread - # If necessary, upcast X[X_start:X_end] to 64bit - self._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): - Y_start = Y_chunk_idx * self.Y_n_samples_chunk - if Y_chunk_idx == self.Y_n_chunks - 1: - Y_end = Y_start + self.Y_n_samples_last_chunk - else: - Y_end = Y_start + self.Y_n_samples_chunk - - # If necessary, upcast Y[Y_start:Y_end] to 64bit - self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - - self._compute_and_reduce_distances_on_chunks( - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - # end: prange - - # Note: we don't need a _parallel_on_Y_finalize similarly. - # This can be introduced if needed. - - # end: with nogil, parallel - - # Synchronizing the thread datastructures with the main ones - self._parallel_on_Y_synchronize(X_start, X_end) - - # end: for X_chunk_idx - # Deallocating temporary datastructures and adjusting main datastructures - self._parallel_on_Y_finalize() - return - - # Placeholder methods which have to be implemented - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Compute the pairwise distances on two chunks of X and Y and reduce them. - - This is THE core computational method of PairwiseDistanceReductions64. - This must be implemented in subclasses agnostically from the parallelization - strategies. - """ - return - - def _finalize_results(self, bint return_distance): - """Callback adapting datastructures before returning results. - - This must be implemented in subclasses. - """ - return None - - # Placeholder methods which can be implemented - - cdef void compute_exact_distances(self) nogil: - """Convert rank-preserving distances to exact distances or recompute them.""" - return - - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - """Allocate datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Initialise datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. - - This is eventually used to upcast X[X_start:X_end] to 64bit. - """ - return - - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Interact with datastructures after a reduction on chunks.""" - return - - cdef void _parallel_on_X_parallel_finalize( - self, - ITYPE_t thread_num - ) nogil: - """Interact with datastructures after executing all the reductions.""" - return - - cdef void _parallel_on_Y_init( - self, - ) nogil: - """Allocate datastructures used in all threads.""" - return - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Initialise datastructures used in a thread given its number.""" - return - - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. - - This is eventually used to upcast Y[Y_start:Y_end] to 64bit. - """ - return - - cdef void _parallel_on_Y_synchronize( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - """Update thread datastructures before leaving a parallel region.""" - return - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - """Update datastructures after executing all the reductions.""" - return - -cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" - - cdef: - ITYPE_t k - - ITYPE_t[:, ::1] argkmin_indices - DTYPE_t[:, ::1] argkmin_distances - - # Used as array of pointers to private datastructures used in threads. - DTYPE_t ** heaps_r_distances_chunks - ITYPE_t ** heaps_indices_chunks - - @classmethod - def compute( - cls, - X, - Y, - ITYPE_t k, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - ): - """Compute the argkmin reduction. - - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin64`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - - No instance should directly be created outside of this class method. - """ - if ( - metric in ("euclidean", "sqeuclidean") - and not issparse(X) - and not issparse(Y) - ): - # Specialized implementation with improved arithmetic intensity - # and vector instructions (SIMD) by processing several vectors - # at time to leverage a call to the BLAS GEMM routine as explained - # in more details in the docstring. - use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesArgKmin64( - X=X, Y=Y, k=k, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - strategy=strategy, - metric_kwargs=metric_kwargs, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesArgKmin64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) - - # Limit the number of threads in second level of nested parallelism for BLAS - # to avoid threads over-subscription (in GEMM for instance). - with threadpool_limits(limits=1, user_api="blas"): - if pda.execute_in_parallel_on_Y: - pda._parallel_on_Y() - else: - pda._parallel_on_X() - - return pda._finalize_results(return_distance) - - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - strategy=None, - ITYPE_t k=1, - ): - super().__init__( - datasets_pair=datasets_pair, - chunk_size=chunk_size, - strategy=strategy, - ) - self.k = check_scalar(k, "k", Integral, min_val=1) - - # Allocating pointers to datastructures but not the datastructures themselves. - # There are as many pointers as effective threads. - # - # For the sake of explicitness: - # - when parallelizing on X, the pointers of those heaps are referencing - # (with proper offsets) addresses of the two main heaps (see below) - # - when parallelizing on Y, the pointers of those heaps are referencing - # small heaps which are thread-wise-allocated and whose content will be - # merged with the main heaps'. - self.heaps_r_distances_chunks = malloc( - sizeof(DTYPE_t *) * self.chunks_n_threads - ) - self.heaps_indices_chunks = malloc( - sizeof(ITYPE_t *) * self.chunks_n_threads - ) - - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin64.compute`. - self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) - self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) - - def __dealloc__(self): - if self.heaps_indices_chunks is not NULL: - free(self.heaps_indices_chunks) - - if self.heaps_r_distances_chunks is not NULL: - free(self.heaps_r_distances_chunks) - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - ITYPE_t n_samples_X = X_end - X_start - ITYPE_t n_samples_Y = Y_end - Y_start - DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] - - # Pushing the distances and their associated indices on a heap - # which by construction will keep track of the argkmin. - for i in range(n_samples_X): - for j in range(n_samples_Y): - heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), - Y_start + j, - ) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - # As this strategy is embarrassingly parallel, we can set each - # thread's heaps pointer to the proper position on the main heaps. - self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] - self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] - - @final - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx - - # Sorting the main heaps portion associated to `X[X_start:X_end]` - # in ascending order w.r.t the distances. - for idx in range(X_end - X_start): - simultaneous_sort( - self.heaps_r_distances_chunks[thread_num] + idx * self.k, - self.heaps_indices_chunks[thread_num] + idx * self.k, - self.k - ) - - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef: - # Maximum number of scalar elements (the last chunks can be smaller) - ITYPE_t heaps_size = self.X_n_samples_chunk * self.k - ITYPE_t thread_num - - # The allocation is done in parallel for data locality purposes: this way - # the heaps used in each threads are allocated in pages which are closer - # to the CPU core used by the thread. - # See comments about First Touch Placement Policy: - # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa - for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True, - num_threads=self.chunks_n_threads): - # As chunks of X are shared across threads, so must their - # heaps. To solve this, each thread has its own heaps - # which are then synchronised back in the main ones. - self.heaps_r_distances_chunks[thread_num] = malloc( - heaps_size * sizeof(DTYPE_t) - ) - self.heaps_indices_chunks[thread_num] = malloc( - heaps_size * sizeof(ITYPE_t) - ) - - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - # Initialising heaps (memset can't be used here) - for idx in range(self.X_n_samples_chunk * self.k): - self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX - self.heaps_indices_chunks[thread_num][idx] = -1 - - @final - cdef void _parallel_on_Y_synchronize( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx, thread_num - with nogil, parallel(num_threads=self.effective_n_threads): - # Synchronising the thread heaps with the main heaps. - # This is done in parallel sample-wise (no need for locks). - # - # This might break each thread's data locality as each heap which - # was allocated in a thread is being now being used in several threads. - # - # Still, this parallel pattern has shown to be efficient in practice. - for idx in prange(X_end - X_start, schedule="static"): - for thread_num in range(self.chunks_n_threads): - for jdx in range(self.k): - heap_push( - &self.argkmin_distances[X_start + idx, 0], - &self.argkmin_indices[X_start + idx, 0], - self.k, - self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], - self.heaps_indices_chunks[thread_num][idx * self.k + jdx], - ) - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef: - ITYPE_t idx, thread_num - - with nogil, parallel(num_threads=self.chunks_n_threads): - # Deallocating temporary datastructures - for thread_num in prange(self.chunks_n_threads, schedule='static'): - free(self.heaps_r_distances_chunks[thread_num]) - free(self.heaps_indices_chunks[thread_num]) - - # Sorting the main in ascending order w.r.t the distances. - # This is done in parallel sample-wise (no need for locks). - for idx in prange(self.n_samples_X, schedule='static'): - simultaneous_sort( - &self.argkmin_distances[idx, 0], - &self.argkmin_indices[idx, 0], - self.k, - ) - return - - cdef void compute_exact_distances(self) nogil: - cdef: - ITYPE_t i, j - ITYPE_t[:, ::1] Y_indices = self.argkmin_indices - DTYPE_t[:, ::1] distances = self.argkmin_distances - for i in prange(self.n_samples_X, schedule='static', nogil=True, - num_threads=self.effective_n_threads): - for j in range(self.k): - distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. - max(distances[i, j], 0.) - ) - - def _finalize_results(self, bint return_distance=False): - if return_distance: - # We need to recompute distances because we relied on - # surrogate distances for the reduction. - self.compute_exact_distances() - - # Values are returned identically to the way `KNeighborsMixin.kneighbors` - # returns values. This is counter-intuitive but this allows not using - # complex adaptations where `PairwiseDistancesArgKmin64.compute` is called. - return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) - - return np.asarray(self.argkmin_indices) - - -cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesArgKmin64.is_usable_for(X, Y, metric) and - not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - ITYPE_t k, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - metric_kwargs=None, - ): - if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), - chunk_size=chunk_size, - strategy=strategy, - k=k, - ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair - cdef: - DenseDenseDatasetsPair datasets_pair = ( - self.datasets_pair - ) - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.gemm_term_computer = GEMMTermComputer64( - datasets_pair.X, - datasets_pair.Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=datasets_pair.X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") - else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) - - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) - ) - self.use_squared_distances = use_squared_distances - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistancesArgKmin64.compute_exact_distances(self) - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_parallel_init(self, thread_num) - self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) - - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesArgKmin64._parallel_on_Y_init(self) - self.gemm_term_computer._parallel_on_Y_init() - - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num] - - - # Pushing the distance and their associated indices on heaps - # which keep tracks of the argkmin. - for i in range(n_X): - for j in range(n_Y): - heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - ( - self.X_norm_squared[i + X_start] + - dist_middle_terms[i * n_Y + j] + - self.Y_norm_squared[j + Y_start] - ), - j + Y_start, - ) - - -cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" - - cdef: - DTYPE_t radius - - # DistanceMetric compute rank-preserving surrogate distance via rdist - # which are proxies necessitating less computations. - # We get the equivalent for the radius to be able to compare it against - # vectors' rank-preserving surrogate distances. - DTYPE_t r_radius - - # Neighbors indices and distances are returned as np.ndarrays of np.ndarrays. - # - # For this implementation, we want resizable buffers which we will wrap - # into numpy arrays at the end. std::vector comes as a handy container - # for interacting efficiently with resizable buffers. - # - # Though it is possible to access their buffer address with - # std::vector::data, they can't be stolen: buffers lifetime - # is tied to their std::vector and are deallocated when - # std::vectors are. - # - # To solve this, we dynamically allocate std::vectors and then - # encapsulate them in a StdVectorSentinel responsible for - # freeing them when the associated np.ndarray is freed. - # - # Shared pointers (defined via shared_ptr) are use for safer memory management. - # Unique pointers (defined via unique_ptr) can't be used as datastructures - # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. - shared_ptr[vector[vector[ITYPE_t]]] neigh_indices - shared_ptr[vector[vector[DTYPE_t]]] neigh_distances - - # Used as array of pointers to private datastructures used in threads. - vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks - vector[shared_ptr[vector[vector[DTYPE_t]]]] neigh_distances_chunks - - bint sort_results - - @classmethod - def compute( - cls, - X, - Y, - DTYPE_t radius, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - bint sort_results=False, - ): - """Compute the radius-neighbors reduction. - - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesRadiusNeighborhood64`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - - No instance should directly be created outside of this class method. - """ - if ( - metric in ("euclidean", "sqeuclidean") - and not issparse(X) - and not issparse(Y) - ): - # Specialized implementation with improved arithmetic intensity - # and vector instructions (SIMD) by processing several vectors - # at time to leverage a call to the BLAS GEMM routine as explained - # in more details in the docstring. - use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesRadiusNeighborhood64( - X=X, Y=Y, radius=radius, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesRadiusNeighborhood64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - radius=radius, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - ) - - # Limit the number of threads in second level of nested parallelism for BLAS - # to avoid threads over-subscription (in GEMM for instance). - with threadpool_limits(limits=1, user_api="blas"): - if pda.execute_in_parallel_on_Y: - pda._parallel_on_Y() - else: - pda._parallel_on_X() - - return pda._finalize_results(return_distance) - - - def __init__( - self, - DatasetsPair datasets_pair, - DTYPE_t radius, - chunk_size=None, - strategy=None, - sort_results=False, - metric_kwargs=None, - ): - super().__init__( - datasets_pair=datasets_pair, - chunk_size=chunk_size, - strategy=strategy, - ) - - self.radius = check_scalar(radius, "radius", Real, min_val=0) - self.r_radius = self.datasets_pair.distance_metric._dist_to_rdist(radius) - self.sort_results = sort_results - - # Allocating pointers to datastructures but not the datastructures themselves. - # There are as many pointers as effective threads. - # - # For the sake of explicitness: - # - when parallelizing on X, the pointers of those heaps are referencing - # self.neigh_distances and self.neigh_indices - # - when parallelizing on Y, the pointers of those heaps are referencing - # std::vectors of std::vectors which are thread-wise-allocated and whose - # content will be merged into self.neigh_distances and self.neigh_indices. - self.neigh_distances_chunks = vector[shared_ptr[vector[vector[DTYPE_t]]]]( - self.chunks_n_threads - ) - self.neigh_indices_chunks = vector[shared_ptr[vector[vector[ITYPE_t]]]]( - self.chunks_n_threads - ) - - # Temporary datastructures which will be coerced to numpy arrays on before - # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. - self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) - - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t r_dist_i_j - - for i in range(X_start, X_end): - for j in range(Y_start, Y_end): - r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) - if r_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i].push_back(r_dist_i_j) - deref(self.neigh_indices_chunks[thread_num])[i].push_back(j) - - def _finalize_results(self, bint return_distance=False): - if return_distance: - # We need to recompute distances because we relied on - # surrogate distances for the reduction. - self.compute_exact_distances() - return ( - coerce_vectors_to_nd_arrays(self.neigh_distances), - coerce_vectors_to_nd_arrays(self.neigh_indices), - ) - - return coerce_vectors_to_nd_arrays(self.neigh_indices) - - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - - # As this strategy is embarrassingly parallel, we can set the - # thread vectors' pointers to the main vectors'. - self.neigh_distances_chunks[thread_num] = self.neigh_distances - self.neigh_indices_chunks[thread_num] = self.neigh_indices - - @final - cdef void _parallel_on_X_prange_iter_finalize( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - cdef: - ITYPE_t idx, jdx - - # Sorting neighbors for each query vector of X - if self.sort_results: - for idx in range(X_start, X_end): - simultaneous_sort( - deref(self.neigh_distances)[idx].data(), - deref(self.neigh_indices)[idx].data(), - deref(self.neigh_indices)[idx].size() - ) - - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef: - ITYPE_t thread_num - # As chunks of X are shared across threads, so must datastructures to avoid race - # conditions: each thread has its own vectors of n_samples_X vectors which are - # then merged back in the main n_samples_X vectors. - for thread_num in range(self.chunks_n_threads): - self.neigh_distances_chunks[thread_num] = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices_chunks[thread_num] = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) - - @final - cdef void _merge_vectors( - self, - ITYPE_t idx, - ITYPE_t num_threads, - ) nogil: - cdef: - ITYPE_t thread_num - ITYPE_t idx_n_elements = 0 - ITYPE_t last_element_idx = deref(self.neigh_indices)[idx].size() - - # Resizing buffers only once for the given number of elements. - for thread_num in range(num_threads): - idx_n_elements += deref(self.neigh_distances_chunks[thread_num])[idx].size() - - deref(self.neigh_distances)[idx].resize(last_element_idx + idx_n_elements) - deref(self.neigh_indices)[idx].resize(last_element_idx + idx_n_elements) - - # Moving the elements by range using the range first element - # as the reference for the insertion. - for thread_num in range(num_threads): - move( - deref(self.neigh_distances_chunks[thread_num])[idx].begin(), - deref(self.neigh_distances_chunks[thread_num])[idx].end(), - deref(self.neigh_distances)[idx].begin() + last_element_idx - ) - move( - deref(self.neigh_indices_chunks[thread_num])[idx].begin(), - deref(self.neigh_indices_chunks[thread_num])[idx].end(), - deref(self.neigh_indices)[idx].begin() + last_element_idx - ) - last_element_idx += deref(self.neigh_distances_chunks[thread_num])[idx].size() - - - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef: - ITYPE_t idx, jdx, thread_num, idx_n_element, idx_current - - with nogil, parallel(num_threads=self.effective_n_threads): - # Merge vectors used in threads into the main ones. - # This is done in parallel sample-wise (no need for locks) - # using dynamic scheduling because we might not have - # the same number of neighbors for each query vector. - for idx in prange(self.n_samples_X, schedule='static'): - self._merge_vectors(idx, self.chunks_n_threads) - - # The content of the vector have been std::moved. - # Hence they can't be used anymore and can be deleted. - # Their deletion is carried out automatically as the - # implementation relies on shared pointers. - - # Sort in parallel in ascending order w.r.t the distances if requested. - if self.sort_results: - for idx in prange(self.n_samples_X, schedule='static'): - simultaneous_sort( - deref(self.neigh_distances)[idx].data(), - deref(self.neigh_indices)[idx].data(), - deref(self.neigh_indices)[idx].size() - ) - - return - - cdef void compute_exact_distances(self) nogil: - """Convert rank-preserving distances to pairwise distances in parallel.""" - cdef: - ITYPE_t i, j - - for i in prange(self.n_samples_X, nogil=True, schedule='static', - num_threads=self.effective_n_threads): - for j in range(deref(self.neigh_indices)[i].size()): - deref(self.neigh_distances)[i][j] = ( - self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. - max(deref(self.neigh_distances)[i][j], 0.) - ) - ) - - -cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesRadiusNeighborhood64.is_usable_for(X, Y, metric) - and not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - DTYPE_t radius, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - sort_results=False, - metric_kwargs=None, - ): - if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (FastEuclideanPairwiseDistancesRadiusNeighborhood) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), - radius=radius, - chunk_size=chunk_size, - strategy=strategy, - sort_results=sort_results, - metric_kwargs=metric_kwargs, - ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair - cdef: - DenseDenseDatasetsPair datasets_pair = self.datasets_pair - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.gemm_term_computer = GEMMTermComputer64( - datasets_pair.X, - datasets_pair.Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=datasets_pair.X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") - else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) - - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) - ) - self.use_squared_distances = use_squared_distances - - if use_squared_distances: - # In this specialisation and this setup, the value passed to the radius is - # already considered to be the adapted radius, so we overwrite it. - self.r_radius = radius - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_parallel_init(self, thread_num) - self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_init(self) - self.gemm_term_computer._parallel_on_Y_init() - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - PairwiseDistancesRadiusNeighborhood64.compute_exact_distances(self) - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t squared_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - # Pushing the distance and their associated indices in vectors. - for i in range(n_X): - for j in range(n_Y): - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - squared_dist_i_j = ( - self.X_norm_squared[i + X_start] - + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[j + Y_start] - ) - if squared_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) - deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py new file mode 100644 index 0000000000000..947f25c6c71e9 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -0,0 +1,102 @@ +# Pairwise Distances Reductions +# ============================= +# +# Author: Julien Jerphanion +# +# Overview +# -------- +# +# This module provides routines to compute pairwise distances between a set +# of row vectors of X and another set of row vectors of Y and apply a +# reduction on top. The canonical example is the brute-force computation +# of the top k nearest neighbors by leveraging the arg-k-min reduction. +# +# The reduction takes a matrix of pairwise distances between rows of X and Y +# as input and outputs an aggregate data-structure for each row of X. The +# aggregate values are typically smaller than the number of rows in Y, hence +# the term reduction. +# +# For computational reasons, the reduction are performed on the fly on chunks +# of rows of X and Y so as to keep intermediate data-structures in CPU cache +# and avoid unnecessary round trips of large distance arrays with the RAM +# that would otherwise severely degrade the speed by making the overall +# processing memory-bound. +# +# Finally, the routines follow a generic parallelization template to process +# chunks of data with OpenMP loops (via Cython prange), either on rows of X +# or rows of Y depending on their respective sizes. +# +# +# Dispatching to specialized implementations +# ------------------------------------------ +# +# Dispatchers are meant to be used in the Python code. Under the hood, a +# dispatcher must only define the logic to choose at runtime to the correct +# dtype-specialized :class:`PairwiseDistancesReduction` implementation based +# on the dtype of X and of Y. +# +# +# High-level diagram +# ------------------ +# +# Legend: +# +# A ---⊳ B: A inherits from B +# A ---x B: A dispatches on B +# +# +# (base dispatcher) +# PairwiseDistancesReduction +# ∆ +# | +# | +# +-----------------+-----------------+ +# | | +# (dispatcher) (dispatcher) +# PairwiseDistancesArgKmin PairwiseDistancesRadiusNeighbors +# | | +# | | +# | | +# | (64bit implem.) | +# | PairwiseDistancesReduction64 | +# | ∆ | +# | | | +# | | | +# | +-----------------+-----------------+ | +# | | | | +# | | | | +# x | | x +# PairwiseDistancesArgKmin64 PairwiseDistancesRadiusNeighbors64 +# | ∆ ∆ | +# | | | | +# x | | | +# FastEuclideanPairwiseDistancesArgKmin64 | | +# | | +# | x +# FastEuclideanPairwiseDistancesRadiusNeighbors64 +# +# For instance :class:`PairwiseDistancesArgKmin`, dispatches to +# :class:`PairwiseDistancesArgKmin64` if X and Y are both dense NumPy arrays +# with a float64 dtype. +# +# In addition, if the metric parameter is set to "euclidean" or "sqeuclidean", +# :class:`PairwiseDistancesArgKmin64` further dispatches to +# :class:`FastEuclideanPairwiseDistancesArgKmin64` a specialized subclass +# to optimally handle the Euclidean distance case using the Generalized Matrix +# Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). + + +from ._base import ( + PairwiseDistancesReduction, + _sqeuclidean_row_norms64, +) + +from ._argkmin import PairwiseDistancesArgKmin +from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood + +__all__ = [ + "PairwiseDistancesReduction", + "PairwiseDistancesArgKmin", + "PairwiseDistancesRadiusNeighborhood", + "_sqeuclidean_row_norms64", +] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd new file mode 100644 index 0000000000000..34d3339e1c9e0 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd @@ -0,0 +1,33 @@ +cimport numpy as cnp + +from ._base cimport ( + PairwiseDistancesReduction64, +) +from ._gemm_term_computer cimport GEMMTermComputer64 + +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +cnp.import_array() + +cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" + + cdef: + ITYPE_t k + + ITYPE_t[:, ::1] argkmin_indices + DTYPE_t[:, ::1] argkmin_distances + + # Used as array of pointers to private datastructures used in threads. + DTYPE_t ** heaps_r_distances_chunks + ITYPE_t ** heaps_indices_chunks + + +cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" + cdef: + GEMMTermComputer64 gemm_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx new file mode 100644 index 0000000000000..f202ec37395fa --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx @@ -0,0 +1,625 @@ +cimport numpy as cnp + +from libc.stdlib cimport free, malloc +from libc.float cimport DBL_MAX +from cython cimport final +from cython.parallel cimport parallel, prange + +from ._base cimport ( + PairwiseDistancesReduction64, + _sqeuclidean_row_norms64, +) + +from ._datasets_pair cimport ( + DatasetsPair, + DenseDenseDatasetsPair, +) + +from ._gemm_term_computer cimport GEMMTermComputer64 + +from ...utils._heap cimport heap_push +from ...utils._sorting cimport simultaneous_sort +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +import numpy as np +import warnings + +from numbers import Integral +from scipy.sparse import issparse +from sklearn.utils import check_scalar, _in_unstable_openblas_configuration +from sklearn.utils.fixes import threadpool_limits +from ...utils._typedefs import ITYPE, DTYPE + +cnp.import_array() + +from ._base import PairwiseDistancesReduction + +class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of row vectors of X on the ones of Y. + + For each row vector of X, computes the indices of k first the rows + vectors of Y with the smallest distances. + + PairwiseDistancesArgKmin is typically used to perform + bruteforce k-nearest neighbors queries. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + ITYPE_t k, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + ): + """Compute the argkmin reduction. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + If return_distance=True: + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + Notes + ----- + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesArgKmin64.compute( + X=X, + Y=Y, + k=k, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + +cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" + + @classmethod + def compute( + cls, + X, + Y, + ITYPE_t k, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + ): + """Compute the argkmin reduction. + + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin64`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + + No instance should directly be created outside of this class method. + """ + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesArgKmin64( + X=X, Y=Y, k=k, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = PairwiseDistancesArgKmin64( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results(return_distance) + + def __init__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + strategy=None, + ITYPE_t k=1, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + strategy=strategy, + ) + self.k = check_scalar(k, "k", Integral, min_val=1) + + # Allocating pointers to datastructures but not the datastructures themselves. + # There are as many pointers as effective threads. + # + # For the sake of explicitness: + # - when parallelizing on X, the pointers of those heaps are referencing + # (with proper offsets) addresses of the two main heaps (see below) + # - when parallelizing on Y, the pointers of those heaps are referencing + # small heaps which are thread-wise-allocated and whose content will be + # merged with the main heaps'. + self.heaps_r_distances_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + self.heaps_indices_chunks = malloc( + sizeof(ITYPE_t *) * self.chunks_n_threads + ) + + # Main heaps which will be returned as results by `PairwiseDistancesArgKmin64.compute`. + self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) + self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) + + def __dealloc__(self): + if self.heaps_indices_chunks is not NULL: + free(self.heaps_indices_chunks) + + if self.heaps_r_distances_chunks is not NULL: + free(self.heaps_r_distances_chunks) + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + ITYPE_t n_samples_X = X_end - X_start + ITYPE_t n_samples_Y = Y_end - Y_start + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Pushing the distances and their associated indices on a heap + # which by construction will keep track of the argkmin. + for i in range(n_samples_X): + for j in range(n_samples_Y): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), + Y_start + j, + ) + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + # As this strategy is embarrassingly parallel, we can set each + # thread's heaps pointer to the proper position on the main heaps. + self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] + self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx + + # Sorting the main heaps portion associated to `X[X_start:X_end]` + # in ascending order w.r.t the distances. + for idx in range(X_end - X_start): + simultaneous_sort( + self.heaps_r_distances_chunks[thread_num] + idx * self.k, + self.heaps_indices_chunks[thread_num] + idx * self.k, + self.k + ) + + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef: + # Maximum number of scalar elements (the last chunks can be smaller) + ITYPE_t heaps_size = self.X_n_samples_chunk * self.k + ITYPE_t thread_num + + # The allocation is done in parallel for data locality purposes: this way + # the heaps used in each threads are allocated in pages which are closer + # to the CPU core used by the thread. + # See comments about First Touch Placement Policy: + # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa + for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True, + num_threads=self.chunks_n_threads): + # As chunks of X are shared across threads, so must their + # heaps. To solve this, each thread has its own heaps + # which are then synchronised back in the main ones. + self.heaps_r_distances_chunks[thread_num] = malloc( + heaps_size * sizeof(DTYPE_t) + ) + self.heaps_indices_chunks[thread_num] = malloc( + heaps_size * sizeof(ITYPE_t) + ) + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + # Initialising heaps (memset can't be used here) + for idx in range(self.X_n_samples_chunk * self.k): + self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX + self.heaps_indices_chunks[thread_num][idx] = -1 + + @final + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx, thread_num + with nogil, parallel(num_threads=self.effective_n_threads): + # Synchronising the thread heaps with the main heaps. + # This is done in parallel sample-wise (no need for locks). + # + # This might break each thread's data locality as each heap which + # was allocated in a thread is being now being used in several threads. + # + # Still, this parallel pattern has shown to be efficient in practice. + for idx in prange(X_end - X_start, schedule="static"): + for thread_num in range(self.chunks_n_threads): + for jdx in range(self.k): + heap_push( + &self.argkmin_distances[X_start + idx, 0], + &self.argkmin_indices[X_start + idx, 0], + self.k, + self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], + self.heaps_indices_chunks[thread_num][idx * self.k + jdx], + ) + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef: + ITYPE_t idx, thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + # Deallocating temporary datastructures + for thread_num in prange(self.chunks_n_threads, schedule='static'): + free(self.heaps_r_distances_chunks[thread_num]) + free(self.heaps_indices_chunks[thread_num]) + + # Sorting the main in ascending order w.r.t the distances. + # This is done in parallel sample-wise (no need for locks). + for idx in prange(self.n_samples_X, schedule='static'): + simultaneous_sort( + &self.argkmin_distances[idx, 0], + &self.argkmin_indices[idx, 0], + self.k, + ) + return + + cdef void compute_exact_distances(self) nogil: + cdef: + ITYPE_t i, j + ITYPE_t[:, ::1] Y_indices = self.argkmin_indices + DTYPE_t[:, ::1] distances = self.argkmin_distances + for i in prange(self.n_samples_X, schedule='static', nogil=True, + num_threads=self.effective_n_threads): + for j in range(self.k): + distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(distances[i, j], 0.) + ) + + def _finalize_results(self, bint return_distance=False): + if return_distance: + # We need to recompute distances because we relied on + # surrogate distances for the reduction. + self.compute_exact_distances() + + # Values are returned identically to the way `KNeighborsMixin.kneighbors` + # returns values. This is counter-intuitive but this allows not using + # complex adaptations where `PairwiseDistancesArgKmin64.compute` is called. + return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) + + return np.asarray(self.argkmin_indices) + + +cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesArgKmin64.is_usable_for(X, Y, metric) and + not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + ITYPE_t k, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + metric_kwargs=None, + ): + if ( + metric_kwargs is not None and + len(metric_kwargs) > 0 and + "Y_norm_squared" not in metric_kwargs + ): + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " + f"usable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + chunk_size=chunk_size, + strategy=strategy, + k=k, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = ( + self.datasets_pair + ) + ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + + self.gemm_term_computer = GEMMTermComputer64( + datasets_pair.X, + datasets_pair.Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + n_features=datasets_pair.X.shape[1], + chunk_size=self.chunk_size, + ) + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + else: + self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesArgKmin64.compute_exact_distances(self) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_X_parallel_init(self, thread_num) + self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) + + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin64._parallel_on_Y_init(self) + self.gemm_term_computer._parallel_on_Y_init() + + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num] + + + # Pushing the distance and their associated indices on heaps + # which keep tracks of the argkmin. + for i in range(n_X): + for j in range(n_Y): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * n_Y + j] + + self.Y_norm_squared[j + Y_start] + ), + j + Y_start, + ) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd new file mode 100644 index 0000000000000..9f6ad45cb839a --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd @@ -0,0 +1,128 @@ +cimport numpy as cnp + +from cython cimport final + +from ._datasets_pair cimport DatasetsPair +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +cnp.import_array() + + +cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +) + +cdef class PairwiseDistancesReduction64: + """Base 64bit implementation of PairwiseDistancesReduction.""" + + cdef: + readonly DatasetsPair datasets_pair + + # The number of threads that can be used is stored in effective_n_threads. + # + # The number of threads to use in the parallelization strategy + # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: + # for small datasets, fewer threads might be needed to loop over pair of chunks. + # + # Hence, the number of threads that _will_ be used for looping over chunks + # is stored in chunks_n_threads, allowing solely using what we need. + # + # Thus, an invariant is: + # + # chunks_n_threads <= effective_n_threads + # + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + + ITYPE_t n_samples_chunk, chunk_size + + ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk + ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + + bint execute_in_parallel_on_Y + + @final + cdef void _parallel_on_X(self) nogil + + @final + cdef void _parallel_on_Y(self) nogil + + # Placeholder methods which have to be implemented + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + + # Placeholder methods which can be implemented + + cdef void compute_exact_distances(self) nogil + + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil + + cdef void _parallel_on_Y_init( + self, + ) nogil + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_Y_finalize( + self, + ) nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx new file mode 100644 index 0000000000000..3b1430e09f80e --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx @@ -0,0 +1,424 @@ +cimport numpy as cnp +import numpy as np + +from sklearn import get_config +from cython cimport final +from cython.parallel cimport parallel, prange + +from ._datasets_pair cimport DatasetsPair +from ...utils._cython_blas cimport _dot +from ...utils._openmp_helpers cimport _openmp_thread_num +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +from numbers import Integral +from typing import List +from scipy.sparse import issparse +from sklearn.utils import check_scalar +from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING +from ...utils._openmp_helpers import _openmp_effective_n_threads +from ...utils._typedefs import ITYPE, DTYPE + +cnp.import_array() + +##################### + +cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t idx = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) + + for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): + squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return squared_row_norms + + + +class PairwiseDistancesReduction: + """Abstract base dispatcher for pairwise distance computation & reduction. + + Each dispatcher extending the base :class:`PairwiseDistancesReduction` + dispatcher must implement the :meth:`compute` classmethod. + """ + + @classmethod + def valid_metrics(cls) -> List[str]: + excluded = { + "pyfunc", # is relatively slow because we need to coerce data as np arrays + "mahalanobis", # is numerically unstable + # TODO: In order to support discrete distance metrics, we need to have a + # stable simultaneous sort which preserves the order of the input. + # The best might be using std::stable_sort and a Comparator taking an + # Arrays of Structures instead of Structure of Arrays (currently used). + "hamming", + *BOOL_METRICS, + } + return sorted(set(METRIC_MAPPING.keys()) - excluded) + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + """Return True if the PairwiseDistancesReduction can be used for the + given parameters. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + Returns + ------- + True if the PairwiseDistancesReduction can be used, else False. + """ + dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 + return (get_config().get("enable_cython_pairwise_dist", True) and + not issparse(X) and not issparse(Y) and dtypes_validity and + metric in cls.valid_metrics()) + +cdef class PairwiseDistancesReduction64: + """Base 64bit implementation of PairwiseDistancesReduction.""" + + def __init__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + strategy=None, + ): + cdef: + ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks + + if chunk_size is None: + chunk_size = get_config().get("pairwise_dist_chunk_size", 256) + + self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) + + self.effective_n_threads = _openmp_effective_n_threads() + + self.datasets_pair = datasets_pair + + self.n_samples_X = datasets_pair.n_samples_X() + self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) + X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk + X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk + self.X_n_chunks = X_n_full_chunks + (X_n_samples_remainder != 0) + + if X_n_samples_remainder != 0: + self.X_n_samples_last_chunk = X_n_samples_remainder + else: + self.X_n_samples_last_chunk = self.X_n_samples_chunk + + self.n_samples_Y = datasets_pair.n_samples_Y() + self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) + Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk + Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk + self.Y_n_chunks = Y_n_full_chunks + (Y_n_samples_remainder != 0) + + if Y_n_samples_remainder != 0: + self.Y_n_samples_last_chunk = Y_n_samples_remainder + else: + self.Y_n_samples_last_chunk = self.Y_n_samples_chunk + + if strategy is None: + strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') + + if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): + raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " + f"or 'auto', but currently strategy='{self.strategy}'.") + + if strategy == 'auto': + # This is a simple heuristic whose constant for the + # comparison has been chosen based on experiments. + if 4 * self.chunk_size * self.effective_n_threads < self.n_samples_X: + strategy = 'parallel_on_X' + else: + strategy = 'parallel_on_Y' + + self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" + + # Not using less, not using more. + self.chunks_n_threads = min( + self.Y_n_chunks if self.execute_in_parallel_on_Y else self.X_n_chunks, + self.effective_n_threads, + ) + + @final + cdef void _parallel_on_X(self) nogil: + """Compute the pairwise distances of each row vector of X on Y + by parallelizing computation on the outer loop on chunks of X + and reduce them. + + This strategy dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. + """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Allocating thread datastructures + self._parallel_on_X_parallel_init(thread_num) + + for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1: + X_end = X_start + self.X_n_samples_last_chunk + else: + X_end = X_start + self.X_n_samples_chunk + + # Reinitializing thread datastructures for the new X chunk + # If necessary, upcast X[X_start:X_end] to 64bit + self._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + for Y_chunk_idx in range(self.Y_n_chunks): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1: + Y_end = Y_start + self.Y_n_samples_last_chunk + else: + Y_end = Y_start + self.Y_n_samples_chunk + + # If necessary, upcast Y[Y_start:Y_end] to 64bit + self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + # Adjusting thread datastructures on the full pass on Y + self._parallel_on_X_prange_iter_finalize(thread_num, X_start, X_end) + + # end: for X_chunk_idx + + # Deallocating thread datastructures + self._parallel_on_X_parallel_finalize(thread_num) + + # end: with nogil, parallel + return + + @final + cdef void _parallel_on_Y(self) nogil: + """Compute the pairwise distances of each row vector of X on Y + by parallelizing computation on the inner loop on chunks of Y + and reduce them. + + This strategy dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. + """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + # Allocating datastructures shared by all threads + self._parallel_on_Y_init() + + for X_chunk_idx in range(self.X_n_chunks): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1: + X_end = X_start + self.X_n_samples_last_chunk + else: + X_end = X_start + self.X_n_samples_chunk + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Initializing datastructures used in this thread + # If necessary, upcast X[X_start:X_end] to 64bit + self._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1: + Y_end = Y_start + self.Y_n_samples_last_chunk + else: + Y_end = Y_start + self.Y_n_samples_chunk + + # If necessary, upcast Y[Y_start:Y_end] to 64bit + self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + # end: prange + + # Note: we don't need a _parallel_on_Y_finalize similarly. + # This can be introduced if needed. + + # end: with nogil, parallel + + # Synchronizing the thread datastructures with the main ones + self._parallel_on_Y_synchronize(X_start, X_end) + + # end: for X_chunk_idx + # Deallocating temporary datastructures and adjusting main datastructures + self._parallel_on_Y_finalize() + return + + # Placeholder methods which have to be implemented + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Compute the pairwise distances on two chunks of X and Y and reduce them. + + This is THE core computational method of PairwiseDistanceReductions64. + This must be implemented in subclasses agnostically from the parallelization + strategies. + """ + return + + def _finalize_results(self, bint return_distance): + """Callback adapting datastructures before returning results. + + This must be implemented in subclasses. + """ + return None + + # Placeholder methods which can be implemented + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to exact distances or recompute them.""" + return + + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + """Allocate datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. + + This is eventually used to upcast X[X_start:X_end] to 64bit. + """ + return + + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Interact with datastructures after a reduction on chunks.""" + return + + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + """Interact with datastructures after executing all the reductions.""" + return + + cdef void _parallel_on_Y_init( + self, + ) nogil: + """Allocate datastructures used in all threads.""" + return + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Initialise datastructures just before the _compute_and_reduce_distances_on_chunks. + + This is eventually used to upcast Y[Y_start:Y_end] to 64bit. + """ + return + + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Update thread datastructures before leaving a parallel region.""" + return + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + """Update datastructures after executing all the reductions.""" + return diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd new file mode 100644 index 0000000000000..de6458f8c6f26 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd @@ -0,0 +1,21 @@ +from ...utils._typedefs cimport DTYPE_t, ITYPE_t +from ...metrics._dist_metrics cimport DistanceMetric + + +cdef class DatasetsPair: + cdef DistanceMetric distance_metric + + cdef ITYPE_t n_samples_X(self) nogil + + cdef ITYPE_t n_samples_Y(self) nogil + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil + + +cdef class DenseDenseDatasetsPair(DatasetsPair): + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + ITYPE_t d diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx new file mode 100644 index 0000000000000..abef1bed098ed --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx @@ -0,0 +1,164 @@ +import numpy as np +cimport numpy as cnp + +from cython cimport final +from scipy.sparse import issparse + +from ...utils._typedefs cimport DTYPE_t, ITYPE_t +from ...metrics._dist_metrics cimport DistanceMetric + +cnp.import_array() + +cdef class DatasetsPair: + """Abstract class which wraps a pair of datasets (X, Y). + + This class allows computing distances between a single pair of rows of + of X and Y at a time given the pair of their indices (i, j). This class is + specialized for each metric thanks to the :func:`get_for` factory classmethod. + + The handling of parallelization over chunks to compute the distances + and aggregation for several rows at a time is done in dedicated + subclasses of PairwiseDistancesReduction that in-turn rely on + subclasses of DatasetsPair for each pair of rows in the data. The goal + is to make it possible to decouple the generic parallelization and + aggregation logic from metric-specific computation as much as + possible. + + X and Y can be stored as C-contiguous np.ndarrays or CSR matrices + in subclasses. + + This class avoids the overhead of dispatching distance computations + to :class:`sklearn.metrics.DistanceMetric` based on the physical + representation of the vectors (sparse vs. dense). It makes use of + cython.final to remove the overhead of dispatching method calls. + + Parameters + ---------- + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two vectors of (X, Y). + """ + + @classmethod + def get_for( + cls, + X, + Y, + str metric="euclidean", + dict metric_kwargs=None, + ) -> DatasetsPair: + """Return the DatasetsPair implementation for the given arguments. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + metric : str, default='euclidean' + The distance metric to compute between rows of X and Y. + The default metric is a fast implementation of the Euclidean + metric. For a list of available metrics, see the documentation + of :class:`~sklearn.metrics.DistanceMetric`. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + Returns + ------- + datasets_pair: DatasetsPair + The suited DatasetsPair implementation. + """ + cdef: + DistanceMetric distance_metric = DistanceMetric.get_metric( + metric, + **(metric_kwargs or {}) + ) + + if not(X.dtype == Y.dtype == np.float64): + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + # Metric-specific checks that do not replace nor duplicate `check_array`. + distance_metric._validate_data(X) + distance_metric._validate_data(Y) + + # TODO: dispatch to other dataset pairs for sparse support once available: + if issparse(X) or issparse(Y): + raise ValueError("Only dense datasets are supported for X and Y.") + + return DenseDenseDatasetsPair(X, Y, distance_metric) + + def __init__(self, DistanceMetric distance_metric): + self.distance_metric = distance_metric + + cdef ITYPE_t n_samples_X(self) nogil: + """Number of samples in X.""" + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -999 + + cdef ITYPE_t n_samples_Y(self) nogil: + """Number of samples in Y.""" + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -999 + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.dist(i, j) + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -1 + +@final +cdef class DenseDenseDatasetsPair(DatasetsPair): + """Compute distances between row vectors of two arrays. + + Parameters + ---------- + X: ndarray of shape (n_samples_X, n_features) + Rows represent vectors. Must be C-contiguous. + + Y: ndarray of shape (n_samples_Y, n_features) + Rows represent vectors. Must be C-contiguous. + + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two row vectors of (X, Y). + """ + + def __init__(self, X, Y, DistanceMetric distance_metric): + super().__init__(distance_metric) + # Arrays have already been checked + self.X = X + self.Y = Y + self.d = X.shape[1] + + @final + cdef ITYPE_t n_samples_X(self) nogil: + return self.X.shape[0] + + @final + cdef ITYPE_t n_samples_Y(self) nogil: + return self.Y.shape[0] + + @final + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.rdist(&self.X[i, 0], &self.Y[j, 0], self.d) + + @final + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.dist(&self.X[i, 0], &self.Y[j, 0], self.d) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd new file mode 100644 index 0000000000000..a1c5bd3a8d80c --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd @@ -0,0 +1,62 @@ +from ...utils._typedefs cimport DTYPE_t, ITYPE_t +from libcpp.vector cimport vector + + +cdef class GEMMTermComputer64: + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + ITYPE_t dist_middle_terms_chunks_size + ITYPE_t n_features + ITYPE_t chunk_size + + # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM + vector[vector[DTYPE_t]] dist_middle_terms_chunks + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_Y_init(self) nogil + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num + ) nogil + + cdef DTYPE_t * _compute_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx new file mode 100644 index 0000000000000..b4281d27bb2eb --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx @@ -0,0 +1,135 @@ +from libcpp.vector cimport vector + +from ...utils._typedefs cimport DTYPE_t, ITYPE_t + +from ...utils._cython_blas cimport ( + BLAS_Order, + BLAS_Trans, + ColMajor, + NoTrans, + RowMajor, + Trans, + _gemm, +) + +cdef class GEMMTermComputer64: + """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. + + `FastEuclidean*` classes internally compute the squared Euclidean distances between + chunks of vectors X_c and Y_c using the following decomposition: + + + ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + + + This helper class is in charge of wrapping the common logic to compute + the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high + arithmetic intensity. + """ + + def __init__(self, + DTYPE_t[:, ::1] X, + DTYPE_t[:, ::1] Y, + ITYPE_t effective_n_threads, + ITYPE_t chunks_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ITYPE_t n_features, + ITYPE_t chunk_size, + ): + self.X = X + self.Y = Y + self.effective_n_threads = effective_n_threads + self.chunks_n_threads = chunks_n_threads + self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size + self.n_features = n_features + self.chunk_size = chunk_size + + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) + + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + return + + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: + self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + return + + cdef void _parallel_on_Y_init(self) nogil: + for thread_num in range(self.chunks_n_threads): + self.dist_middle_terms_chunks[thread_num].resize( + self.dist_middle_terms_chunks_size + ) + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + return + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num + ) nogil: + return + + cdef DTYPE_t * _compute_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * A = &X_c[0, 0] + DTYPE_t * B = &Y_c[0, 0] + ITYPE_t lda = X_c.shape[1] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + + return dist_middle_terms diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd new file mode 100644 index 0000000000000..29630a26d0d27 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd @@ -0,0 +1,89 @@ +cimport numpy as cnp + +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector +from cython cimport final + +from ._base cimport ( + PairwiseDistancesReduction64, +) +from ._gemm_term_computer cimport GEMMTermComputer64 + +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +cnp.import_array() + +###################### +## std::vector to np.ndarray coercion +# As type covariance is not supported for C++ containers via Cython, +# we need to redefine fused types. +ctypedef fused vector_DITYPE_t: + vector[ITYPE_t] + vector[DTYPE_t] + + +ctypedef fused vector_vector_DITYPE_t: + vector[vector[ITYPE_t]] + vector[vector[DTYPE_t]] + +cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( + shared_ptr[vector_vector_DITYPE_t] vecs +) + +##################### + +cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" + + cdef: + DTYPE_t radius + + # DistanceMetric compute rank-preserving surrogate distance via rdist + # which are proxies necessitating less computations. + # We get the equivalent for the radius to be able to compare it against + # vectors' rank-preserving surrogate distances. + DTYPE_t r_radius + + # Neighbors indices and distances are returned as np.ndarrays of np.ndarrays. + # + # For this implementation, we want resizable buffers which we will wrap + # into numpy arrays at the end. std::vector comes as a handy container + # for interacting efficiently with resizable buffers. + # + # Though it is possible to access their buffer address with + # std::vector::data, they can't be stolen: buffers lifetime + # is tied to their std::vector and are deallocated when + # std::vectors are. + # + # To solve this, we dynamically allocate std::vectors and then + # encapsulate them in a StdVectorSentinel responsible for + # freeing them when the associated np.ndarray is freed. + # + # Shared pointers (defined via shared_ptr) are use for safer memory management. + # Unique pointers (defined via unique_ptr) can't be used as datastructures + # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. + shared_ptr[vector[vector[ITYPE_t]]] neigh_indices + shared_ptr[vector[vector[DTYPE_t]]] neigh_distances + + # Used as array of pointers to private datastructures used in threads. + vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks + vector[shared_ptr[vector[vector[DTYPE_t]]]] neigh_distances_chunks + + bint sort_results + + @final + cdef void _merge_vectors( + self, + ITYPE_t idx, + ITYPE_t num_threads, + ) nogil + + +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" + cdef: + GEMMTermComputer64 gemm_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx new file mode 100644 index 0000000000000..a42b3aa5c7f50 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx @@ -0,0 +1,638 @@ +cimport numpy as cnp +import numpy as np +import warnings + +from libcpp.memory cimport shared_ptr, make_shared +from libcpp.vector cimport vector +from cython cimport final +from cython.operator cimport dereference as deref +from cython.parallel cimport parallel, prange + +from ._base cimport ( + PairwiseDistancesReduction64, + _sqeuclidean_row_norms64 +) + +from ._datasets_pair cimport ( + DatasetsPair, + DenseDenseDatasetsPair, +) + +from ._gemm_term_computer cimport GEMMTermComputer64 + +from ...utils._sorting cimport simultaneous_sort +from ...utils._typedefs cimport ITYPE_t, DTYPE_t +from ...utils._vector_sentinel cimport vector_to_nd_array + +from numbers import Real +from scipy.sparse import issparse +from sklearn.utils import check_scalar, _in_unstable_openblas_configuration +from sklearn.utils.fixes import threadpool_limits + +from ._base import PairwiseDistancesReduction + +cnp.import_array() + +# TODO: change for `libcpp.algorithm.move` once Cython 3 is used +# Introduction in Cython: +# https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L47 #noqa +cdef extern from "" namespace "std" nogil: + OutputIt move[InputIt, OutputIt](InputIt first, InputIt last, OutputIt d_first) except + #noqa + +###################### + +cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( + shared_ptr[vector_vector_DITYPE_t] vecs +): + """Coerce a std::vector of std::vector to a ndarray of ndarray.""" + cdef: + ITYPE_t n = deref(vecs).size() + cnp.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, dtype=np.ndarray) + + for i in range(n): + nd_arrays_of_nd_arrays[i] = vector_to_nd_array(&(deref(vecs)[i])) + + return nd_arrays_of_nd_arrays + +##################### + +class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): + """Compute radius-based neighbors for two sets of vectors. + + For each row-vector X[i] of the queries X, find all the indices j of + row-vectors in Y such that: + + dist(X[i], Y[j]) <= radius + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + DTYPE_t radius, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + bint sort_results=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + radius : float + The radius defining the neighborhood. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its neighbors if set to True. + + sort_results : boolean, default=False + Sort results with respect to distances between each X vector and its + neighbors if set to True. + + Returns + ------- + If return_distance=False: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + + If return_distance=True: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + - neighbors_distances : ndarray of n_samples_X ndarray + Distances to the neighbors for each vector in X. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistancesRadiusNeighborhood`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesRadiusNeighborhood64.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + sort_results=sort_results, + return_distance=return_distance, + ) + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + +cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): + """64bit implementation of PairwiseDistancesArgKmin.""" + + @classmethod + def compute( + cls, + X, + Y, + DTYPE_t radius, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + bint return_distance=False, + bint sort_results=False, + ): + """Compute the radius-neighbors reduction. + + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesRadiusNeighborhood64`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + + No instance should directly be created outside of this class method. + """ + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesRadiusNeighborhood64( + X=X, Y=Y, radius=radius, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + sort_results=sort_results, + ) + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = PairwiseDistancesRadiusNeighborhood64( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + radius=radius, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + sort_results=sort_results, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results(return_distance) + + + def __init__( + self, + DatasetsPair datasets_pair, + DTYPE_t radius, + chunk_size=None, + strategy=None, + sort_results=False, + metric_kwargs=None, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + strategy=strategy, + ) + + self.radius = check_scalar(radius, "radius", Real, min_val=0) + self.r_radius = self.datasets_pair.distance_metric._dist_to_rdist(radius) + self.sort_results = sort_results + + # Allocating pointers to datastructures but not the datastructures themselves. + # There are as many pointers as effective threads. + # + # For the sake of explicitness: + # - when parallelizing on X, the pointers of those heaps are referencing + # self.neigh_distances and self.neigh_indices + # - when parallelizing on Y, the pointers of those heaps are referencing + # std::vectors of std::vectors which are thread-wise-allocated and whose + # content will be merged into self.neigh_distances and self.neigh_indices. + self.neigh_distances_chunks = vector[shared_ptr[vector[vector[DTYPE_t]]]]( + self.chunks_n_threads + ) + self.neigh_indices_chunks = vector[shared_ptr[vector[vector[ITYPE_t]]]]( + self.chunks_n_threads + ) + + # Temporary datastructures which will be coerced to numpy arrays on before + # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. + self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t r_dist_i_j + + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): + r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) + if r_dist_i_j <= self.r_radius: + deref(self.neigh_distances_chunks[thread_num])[i].push_back(r_dist_i_j) + deref(self.neigh_indices_chunks[thread_num])[i].push_back(j) + + def _finalize_results(self, bint return_distance=False): + if return_distance: + # We need to recompute distances because we relied on + # surrogate distances for the reduction. + self.compute_exact_distances() + return ( + coerce_vectors_to_nd_arrays(self.neigh_distances), + coerce_vectors_to_nd_arrays(self.neigh_indices), + ) + + return coerce_vectors_to_nd_arrays(self.neigh_indices) + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + + # As this strategy is embarrassingly parallel, we can set the + # thread vectors' pointers to the main vectors'. + self.neigh_distances_chunks[thread_num] = self.neigh_distances + self.neigh_indices_chunks[thread_num] = self.neigh_indices + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx + + # Sorting neighbors for each query vector of X + if self.sort_results: + for idx in range(X_start, X_end): + simultaneous_sort( + deref(self.neigh_distances)[idx].data(), + deref(self.neigh_indices)[idx].data(), + deref(self.neigh_indices)[idx].size() + ) + + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef: + ITYPE_t thread_num + # As chunks of X are shared across threads, so must datastructures to avoid race + # conditions: each thread has its own vectors of n_samples_X vectors which are + # then merged back in the main n_samples_X vectors. + for thread_num in range(self.chunks_n_threads): + self.neigh_distances_chunks[thread_num] = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices_chunks[thread_num] = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) + + @final + cdef void _merge_vectors( + self, + ITYPE_t idx, + ITYPE_t num_threads, + ) nogil: + cdef: + ITYPE_t thread_num + ITYPE_t idx_n_elements = 0 + ITYPE_t last_element_idx = deref(self.neigh_indices)[idx].size() + + # Resizing buffers only once for the given number of elements. + for thread_num in range(num_threads): + idx_n_elements += deref(self.neigh_distances_chunks[thread_num])[idx].size() + + deref(self.neigh_distances)[idx].resize(last_element_idx + idx_n_elements) + deref(self.neigh_indices)[idx].resize(last_element_idx + idx_n_elements) + + # Moving the elements by range using the range first element + # as the reference for the insertion. + for thread_num in range(num_threads): + move( + deref(self.neigh_distances_chunks[thread_num])[idx].begin(), + deref(self.neigh_distances_chunks[thread_num])[idx].end(), + deref(self.neigh_distances)[idx].begin() + last_element_idx + ) + move( + deref(self.neigh_indices_chunks[thread_num])[idx].begin(), + deref(self.neigh_indices_chunks[thread_num])[idx].end(), + deref(self.neigh_indices)[idx].begin() + last_element_idx + ) + last_element_idx += deref(self.neigh_distances_chunks[thread_num])[idx].size() + + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef: + ITYPE_t idx, jdx, thread_num, idx_n_element, idx_current + + with nogil, parallel(num_threads=self.effective_n_threads): + # Merge vectors used in threads into the main ones. + # This is done in parallel sample-wise (no need for locks) + # using dynamic scheduling because we might not have + # the same number of neighbors for each query vector. + for idx in prange(self.n_samples_X, schedule='static'): + self._merge_vectors(idx, self.chunks_n_threads) + + # The content of the vector have been std::moved. + # Hence they can't be used anymore and can be deleted. + # Their deletion is carried out automatically as the + # implementation relies on shared pointers. + + # Sort in parallel in ascending order w.r.t the distances if requested. + if self.sort_results: + for idx in prange(self.n_samples_X, schedule='static'): + simultaneous_sort( + deref(self.neigh_distances)[idx].data(), + deref(self.neigh_indices)[idx].data(), + deref(self.neigh_indices)[idx].size() + ) + + return + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to pairwise distances in parallel.""" + cdef: + ITYPE_t i, j + + for i in prange(self.n_samples_X, nogil=True, schedule='static', + num_threads=self.effective_n_threads): + for j in range(deref(self.neigh_indices)[i].size()): + deref(self.neigh_distances)[i][j] = ( + self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(deref(self.neigh_distances)[i][j], 0.) + ) + ) + + +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): + """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesRadiusNeighborhood64.is_usable_for(X, Y, metric) + and not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + DTYPE_t radius, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + sort_results=False, + metric_kwargs=None, + ): + if ( + metric_kwargs is not None and + len(metric_kwargs) > 0 and + "Y_norm_squared" not in metric_kwargs + ): + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " + f"usable for this case (FastEuclideanPairwiseDistancesRadiusNeighborhood) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + radius=radius, + chunk_size=chunk_size, + strategy=strategy, + sort_results=sort_results, + metric_kwargs=metric_kwargs, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + + self.gemm_term_computer = GEMMTermComputer64( + datasets_pair.X, + datasets_pair.Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + n_features=datasets_pair.X.shape[1], + chunk_size=self.chunk_size, + ) + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + else: + self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + if use_squared_distances: + # In this specialisation and this setup, the value passed to the radius is + # already considered to be the adapted radius, so we overwrite it. + self.r_radius = radius + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_parallel_init(self, thread_num) + self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_init(self) + self.gemm_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesRadiusNeighborhood64.compute_exact_distances(self) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + + # Pushing the distance and their associated indices in vectors. + for i in range(n_X): + for j in range(n_Y): + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + squared_dist_i_j = ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * n_Y + j] + + self.Y_norm_squared[j + Y_start] + ) + if squared_dist_i_j <= self.r_radius: + deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) + deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) diff --git a/sklearn/metrics/_pairwise_distances_reduction/setup.py b/sklearn/metrics/_pairwise_distances_reduction/setup.py new file mode 100644 index 0000000000000..ddde0717835da --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/setup.py @@ -0,0 +1,40 @@ +import os + +import numpy as np +from numpy.distutils.misc_util import Configuration + + +def configuration(parent_package="", top_path=None): + config = Configuration("_pairwise_distances_reduction", parent_package, top_path) + libraries = [] + if os.name == "posix": + libraries.append("m") + + cython_sources = [ + "_datasets_pair.pyx", + "_gemm_term_computer.pyx", + "_base.pyx", + "_argkmin.pyx", + "_radius_neighborhood.pyx", + ] + + for source_file in cython_sources: + private_extension_name = source_file.replace(".pyx", "") + config.add_extension( + name=private_extension_name, + sources=[source_file], + include_dirs=[np.get_include()], + language="c++", + libraries=libraries, + extra_compile_args=["-std=c++11"], + ) + + config.add_subpackage("tests") + + return config + + +if __name__ == "__main__": + from numpy.distutils.core import setup + + setup(**configuration().todict()) diff --git a/sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py similarity index 98% rename from sklearn/metrics/tests/test_pairwise_distances_reduction.py rename to sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py index b47407f3754ee..6b63582c4e935 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py @@ -20,6 +20,7 @@ from sklearn.utils._testing import ( assert_array_equal, assert_allclose, + assert_radius_neighborhood_results_equality, ) # Common supported metric between scipy.spatial.distance.cdist @@ -177,25 +178,6 @@ def assert_argkmin_results_quasi_equality( ), msg -def assert_radius_neighborhood_results_equality( - ref_dist, dist, ref_indices, indices, radius -): - # We get arrays of arrays and we need to check for individual pairs - for i in range(ref_dist.shape[0]): - assert (ref_dist[i] <= radius).all() - assert_array_equal( - ref_indices[i], - indices[i], - err_msg=f"Query vector #{i} has different neighbors' indices", - ) - assert_allclose( - ref_dist[i], - dist[i], - err_msg=f"Query vector #{i} has different neighbors' distances", - rtol=1e-7, - ) - - def assert_radius_neighborhood_results_quasi_equality( ref_dist, dist, @@ -955,7 +937,7 @@ def test_pairwise_distances_radius_neighbors( @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) @pytest.mark.parametrize("num_threads", [1, 2, 8]) -def test_sqeuclidean_row_norms( +def test_sqeuclidean_row_norms64( global_random_seed, n_samples, n_features, diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index fc912068cb6c4..e6e13a8c3e030 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -16,6 +16,7 @@ def configuration(parent_package="", top_path=None): config.add_subpackage("_plot") config.add_subpackage("_plot.tests") config.add_subpackage("cluster") + config.add_subpackage("_pairwise_distances_reduction") config.add_extension( "_pairwise_fast", sources=["_pairwise_fast.pyx"], libraries=libraries @@ -35,15 +36,6 @@ def configuration(parent_package="", top_path=None): libraries=libraries, ) - config.add_extension( - "_pairwise_distances_reduction", - sources=["_pairwise_distances_reduction.pyx"], - include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")], - language="c++", - libraries=libraries, - extra_compile_args=["-std=c++11"], - ) - config.add_subpackage("tests") return config diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index faffa8bf85265..da548e4a9f046 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -28,9 +28,6 @@ from sklearn.exceptions import NotFittedError from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS -from sklearn.metrics.tests.test_pairwise_distances_reduction import ( - assert_radius_neighborhood_results_equality, -) from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from sklearn.neighbors import ( @@ -47,6 +44,7 @@ from sklearn.utils._testing import ( assert_allclose, assert_array_equal, + assert_radius_neighborhood_results_equality, ) from sklearn.utils._testing import ignore_warnings from sklearn.utils.validation import check_random_state diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 8a94b1f31abee..4158f7fa65c83 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -1151,3 +1151,22 @@ def transform(self, X, y=None): def fit_transform(self, X, y=None): return self.fit(X, y).transform(X, y) + + +def assert_radius_neighborhood_results_equality( + ref_dist, dist, ref_indices, indices, radius +): + # We get arrays of arrays and we need to check for individual pairs + for i in range(ref_dist.shape[0]): + assert (ref_dist[i] <= radius).all() + assert_array_equal( + ref_indices[i], + indices[i], + err_msg=f"Query vector #{i} has different neighbors' indices", + ) + assert_allclose( + ref_dist[i], + dist[i], + err_msg=f"Query vector #{i} has different neighbors' distances", + rtol=1e-7, + ) From 5c855fe102b15dcbbf976e43f4261f2530596126 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 10 Jun 2022 11:39:27 +0200 Subject: [PATCH 09/19] MAINT Group dispatchers in a single file --- .../_pairwise_distances_reduction/__init__.py | 10 +- .../_argkmin.pyx | 124 ------- .../_pairwise_distances_reduction/_base.pyx | 52 --- .../_dispatcher.py | 323 ++++++++++++++++++ .../_radius_neighborhood.pyx | 133 -------- 5 files changed, 328 insertions(+), 314 deletions(-) create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 947f25c6c71e9..76c62c09e7287 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -86,14 +86,14 @@ # Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). -from ._base import ( +from ._base import _sqeuclidean_row_norms64 + +from ._dispatcher import ( PairwiseDistancesReduction, - _sqeuclidean_row_norms64, + PairwiseDistancesArgKmin, + PairwiseDistancesRadiusNeighborhood, ) -from ._argkmin import PairwiseDistancesArgKmin -from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood - __all__ = [ "PairwiseDistancesReduction", "PairwiseDistancesArgKmin", diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx index f202ec37395fa..2f378543e1f97 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx @@ -32,130 +32,6 @@ from ...utils._typedefs import ITYPE, DTYPE cnp.import_array() -from ._base import PairwiseDistancesReduction - -class PairwiseDistancesArgKmin(PairwiseDistancesReduction): - """Compute the argkmin of row vectors of X on the ones of Y. - - For each row vector of X, computes the indices of k first the rows - vectors of Y with the smallest distances. - - PairwiseDistancesArgKmin is typically used to perform - bruteforce k-nearest neighbors queries. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - ITYPE_t k, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - ): - """Compute the argkmin reduction. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - k : int - The k for the argkmin reduction. - - metric : str, default='euclidean' - The distance metric to use for argkmin. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its - argkmin if set to True. - - Returns - ------- - If return_distance=False: - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - If return_distance=True: - - argkmin_distances : ndarray of shape (n_samples_X, k) - Distances to the argkmin for each vector in X. - - argkmin_indices : ndarray of shape (n_samples_X, k) - Indices of the argkmin for each vector in X. - - Notes - ----- - This classmethod is responsible for introspecting the arguments - values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin`. - - This allows decoupling the API entirely from the implementation details - whilst maintaining RAII: all temporarily allocated datastructures necessary - for the concrete implementation are therefore freed when this classmethod - returns. - """ - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesArgKmin64.compute( - X=X, - Y=Y, - k=k, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): """64bit implementation of PairwiseDistancesArgKmin.""" diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx index 3b1430e09f80e..07506e3616a74 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx @@ -11,10 +11,7 @@ from ...utils._openmp_helpers cimport _openmp_thread_num from ...utils._typedefs cimport ITYPE_t, DTYPE_t from numbers import Integral -from typing import List -from scipy.sparse import issparse from sklearn.utils import check_scalar -from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING from ...utils._openmp_helpers import _openmp_effective_n_threads from ...utils._typedefs import ITYPE, DTYPE @@ -47,55 +44,6 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( return squared_row_norms - -class PairwiseDistancesReduction: - """Abstract base dispatcher for pairwise distance computation & reduction. - - Each dispatcher extending the base :class:`PairwiseDistancesReduction` - dispatcher must implement the :meth:`compute` classmethod. - """ - - @classmethod - def valid_metrics(cls) -> List[str]: - excluded = { - "pyfunc", # is relatively slow because we need to coerce data as np arrays - "mahalanobis", # is numerically unstable - # TODO: In order to support discrete distance metrics, we need to have a - # stable simultaneous sort which preserves the order of the input. - # The best might be using std::stable_sort and a Comparator taking an - # Arrays of Structures instead of Structure of Arrays (currently used). - "hamming", - *BOOL_METRICS, - } - return sorted(set(METRIC_MAPPING.keys()) - excluded) - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - """Return True if the PairwiseDistancesReduction can be used for the - given parameters. - - Parameters - ---------- - X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) - Input data. - - Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) - Input data. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - Returns - ------- - True if the PairwiseDistancesReduction can be used, else False. - """ - dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 - return (get_config().get("enable_cython_pairwise_dist", True) and - not issparse(X) and not issparse(Y) and dtypes_validity and - metric in cls.valid_metrics()) - cdef class PairwiseDistancesReduction64: """Base 64bit implementation of PairwiseDistancesReduction.""" diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py new file mode 100644 index 0000000000000..e4fea34125b96 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -0,0 +1,323 @@ +import numpy as np + +from typing import List +from scipy.sparse import issparse +from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING + +from ._argkmin import PairwiseDistancesArgKmin64 +from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood64 + +from ... import get_config + + +class PairwiseDistancesReduction: + """Abstract base dispatcher for pairwise distance computation & reduction. + + Each dispatcher extending the base :class:`PairwiseDistancesReduction` + dispatcher must implement the :meth:`compute` classmethod. + """ + + @classmethod + def valid_metrics(cls) -> List[str]: + excluded = { + "pyfunc", # is relatively slow because we need to coerce data as np arrays + "mahalanobis", # is numerically unstable + # TODO: In order to support discrete distance metrics, we need to have a + # stable simultaneous sort which preserves the order of the input. + # The best might be using std::stable_sort and a Comparator taking an + # Arrays of Structures instead of Structure of Arrays (currently used). + "hamming", + *BOOL_METRICS, + } + return sorted(set(METRIC_MAPPING.keys()) - excluded) + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + """Return True if the PairwiseDistancesReduction can be used for the + given parameters. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + Returns + ------- + True if the PairwiseDistancesReduction can be used, else False. + """ + dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 + return ( + get_config().get("enable_cython_pairwise_dist", True) + and not issparse(X) + and not issparse(Y) + and dtypes_validity + and metric in cls.valid_metrics() + ) + + +class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of row vectors of X on the ones of Y. + + For each row vector of X, computes the indices of k first the rows + vectors of Y with the smallest distances. + + PairwiseDistancesArgKmin is typically used to perform + bruteforce k-nearest neighbors queries. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + k, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + return_distance=False, + ): + """Compute the argkmin reduction. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures + synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. + When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' + brings more opportunity for parallelism and is therefore more efficient + despite the synchronization step at each iteration of the outer loop + on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + If return_distance=True: + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + Notes + ----- + This classmethod is responsible for introspecting the arguments + values to dispatch to the most appropriate implementation of + :class:`PairwiseDistancesArgKmin`. + + This allows decoupling the API entirely from the implementation details + whilst maintaining RAII: all temporarily allocated datastructures necessary + for the concrete implementation are therefore freed when this classmethod + returns. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesArgKmin64.compute( + X=X, + Y=Y, + k=k, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( + "Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) + + +class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): + """Compute radius-based neighbors for two sets of vectors. + + For each row-vector X[i] of the queries X, find all the indices j of + row-vectors in Y such that: + + dist(X[i], Y[j]) <= radius + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + This class is not meant to be instanciated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def compute( + cls, + X, + Y, + radius, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + return_distance=False, + sort_results=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + radius : float + The radius defining the neighborhood. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures + synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. + When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' + brings more opportunity for parallelism and is therefore more efficient + despite the synchronization step at each iteration of the outer loop + on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its neighbors if set to True. + + sort_results : boolean, default=False + Sort results with respect to distances between each X vector and its + neighbors if set to True. + + Returns + ------- + If return_distance=False: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + + If return_distance=True: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + - neighbors_distances : ndarray of n_samples_X ndarray + Distances to the neighbors for each vector in X. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private dtype-specialized implementation of + :class:`PairwiseDistancesRadiusNeighborhood`. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the API entirely from the + implementation details whilst maintaining RAII. + """ + if X.dtype == Y.dtype == np.float64: + return PairwiseDistancesRadiusNeighborhood64.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + sort_results=sort_results, + return_distance=return_distance, + ) + raise ValueError( + "Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx index a42b3aa5c7f50..8ac4951cb5820 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx @@ -29,8 +29,6 @@ from scipy.sparse import issparse from sklearn.utils import check_scalar, _in_unstable_openblas_configuration from sklearn.utils.fixes import threadpool_limits -from ._base import PairwiseDistancesReduction - cnp.import_array() # TODO: change for `libcpp.algorithm.move` once Cython 3 is used @@ -56,137 +54,6 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( ##################### -class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): - """Compute radius-based neighbors for two sets of vectors. - - For each row-vector X[i] of the queries X, find all the indices j of - row-vectors in Y such that: - - dist(X[i], Y[j]) <= radius - - The distance function `dist` depends on the values of the `metric` - and `metric_kwargs` parameters. - - This class is not meant to be instanciated, one should only use - its :meth:`compute` classmethod which handles allocation and - deallocation consistently. - """ - - @classmethod - def compute( - cls, - X, - Y, - DTYPE_t radius, - str metric="euclidean", - chunk_size=None, - dict metric_kwargs=None, - str strategy=None, - bint return_distance=False, - bint sort_results=False, - ): - """Return the results of the reduction for the given arguments. - - Parameters - ---------- - X : ndarray or CSR matrix of shape (n_samples_X, n_features) - Input data. - - Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) - Input data. - - radius : float - The radius defining the neighborhood. - - metric : str, default='euclidean' - The distance metric to use. - For a list of available metrics, see the documentation of - :class:`~sklearn.metrics.DistanceMetric`. - - chunk_size : int, default=None, - The number of vectors per chunk. If None (default) looks-up in - scikit-learn configuration for `pairwise_dist_chunk_size`, - and use 256 if it is not set. - - metric_kwargs : dict, default=None - Keyword arguments to pass to specified metric function. - - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None - The chunking strategy defining which dataset parallelization are made on. - - For both strategies the computations happens with two nested loops, - respectively on chunks of X and chunks of Y. - Strategies differs on which loop (outer or inner) is made to run - in parallel with the Cython `prange` construct: - - - 'parallel_on_X' dispatches chunks of X uniformly on threads. - Each thread then iterates on all the chunks of Y. This strategy is - embarrassingly parallel and comes with no datastructures synchronisation. - - - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. - Each thread processes all the chunks of X in turn. This strategy is - a sequence of embarrassingly parallel subtasks (the inner loop on Y - chunks) with intermediate datastructures synchronisation at each - iteration of the sequential outer loop on X chunks. - - - 'auto' relies on a simple heuristic to choose between - 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, - 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` - is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity - for parallelism and is therefore more efficient despite the synchronization - step at each iteration of the outer loop on chunks of `X`. - - - None (default) looks-up in scikit-learn configuration for - `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. - - return_distance : boolean, default=False - Return distances between each X vector and its neighbors if set to True. - - sort_results : boolean, default=False - Sort results with respect to distances between each X vector and its - neighbors if set to True. - - Returns - ------- - If return_distance=False: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - If return_distance=True: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - neighbors_distances : ndarray of n_samples_X ndarray - Distances to the neighbors for each vector in X. - - Notes - ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private dtype-specialized implementation of - :class:`PairwiseDistancesRadiusNeighborhood`. - - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. - - This allows entirely decoupling the API entirely from the - implementation details whilst maintaining RAII. - """ - if X.dtype == Y.dtype == np.float64: - return PairwiseDistancesRadiusNeighborhood64.compute( - X=X, - Y=Y, - radius=radius, - metric=metric, - chunk_size=chunk_size, - metric_kwargs=metric_kwargs, - strategy=strategy, - sort_results=sort_results, - return_distance=return_distance, - ) - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): """64bit implementation of PairwiseDistancesArgKmin.""" From a8c7dfa8cde5f4891c338c2d6174d619a88f69db Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 15 Jun 2022 15:15:25 +0200 Subject: [PATCH 10/19] DOC Fix typo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> --- .../_pairwise_distances_reduction/_radius_neighborhood.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd index 29630a26d0d27..737e6888a8a55 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd @@ -33,7 +33,7 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( ##################### cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" + """64bit implementation of PairwiseDistancesRadiusNeighborhood .""" cdef: DTYPE_t radius From fd754d5409400e40e2e17e25a7b83079b6ec00e8 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 15 Jun 2022 16:21:01 +0200 Subject: [PATCH 11/19] DOC Move comment where appropriate --- .../_pairwise_distances_reduction/_dispatcher.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index e4fea34125b96..f4b87370c2b6d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -172,6 +172,11 @@ def compute( for the concrete implementation are therefore freed when this classmethod returns. """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. if X.dtype == Y.dtype == np.float64: return PairwiseDistancesArgKmin64.compute( X=X, @@ -305,6 +310,11 @@ def compute( This allows entirely decoupling the API entirely from the implementation details whilst maintaining RAII. """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various backend and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. if X.dtype == Y.dtype == np.float64: return PairwiseDistancesRadiusNeighborhood64.compute( X=X, From 79c8188dde7028f551d8ff734f07e9967733b7a1 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 15 Jun 2022 17:12:45 +0200 Subject: [PATCH 12/19] MAINT Dispatch _sqeuclidean_row_norms Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction/__init__.py | 5 ++--- .../_pairwise_distances_reduction/_dispatcher.py | 10 ++++++++++ .../tests/test_pairwise_distances_reduction.py | 6 +++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 76c62c09e7287..89778ae9fc5e0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -86,17 +86,16 @@ # Multiplication (see the docstring of :class:`GEMMTermComputer64` for details). -from ._base import _sqeuclidean_row_norms64 - from ._dispatcher import ( PairwiseDistancesReduction, PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood, + _sqeuclidean_row_norms, ) __all__ = [ "PairwiseDistancesReduction", "PairwiseDistancesArgKmin", "PairwiseDistancesRadiusNeighborhood", - "_sqeuclidean_row_norms64", + "_sqeuclidean_row_norms", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index f4b87370c2b6d..c0193b33dec2b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -4,12 +4,22 @@ from scipy.sparse import issparse from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING +from ._base import _sqeuclidean_row_norms64 from ._argkmin import PairwiseDistancesArgKmin64 from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood64 from ... import get_config +def _sqeuclidean_row_norms(X, num_threads): + """Compute the squared euclidean norm of the rows of X in parallel.""" + if X.dtype == np.float64: + return _sqeuclidean_row_norms64(X, num_threads) + raise ValueError( + f"Only 64bit float datasets are supported at this time, got: X.dtype={X.dtype}." + ) + + class PairwiseDistancesReduction: """Abstract base dispatcher for pairwise distance computation & reduction. diff --git a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py index 6b63582c4e935..931790abd4425 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py @@ -12,7 +12,7 @@ PairwiseDistancesReduction, PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood, - _sqeuclidean_row_norms64, + _sqeuclidean_row_norms, ) from sklearn.metrics import euclidean_distances @@ -937,7 +937,7 @@ def test_pairwise_distances_radius_neighbors( @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) @pytest.mark.parametrize("num_threads", [1, 2, 8]) -def test_sqeuclidean_row_norms64( +def test_sqeuclidean_row_norms( global_random_seed, n_samples, n_features, @@ -949,6 +949,6 @@ def test_sqeuclidean_row_norms64( X = rng.rand(n_samples, n_features).astype(dtype) * spread sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 - sq_row_norm = np.asarray(_sqeuclidean_row_norms64(X, num_threads=num_threads)) + sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) assert_allclose(sq_row_norm_reference, sq_row_norm) From 91420d6535a56e8b76383d0ff664a260446aa9cc Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 22 Jun 2022 15:23:21 +0200 Subject: [PATCH 13/19] Correctly handle read-only datasets Also reword comments in tests --- .../_gemm_term_computer.pyx | 4 +- .../test_pairwise_distances_reduction.py | 57 ++++++++++++++++++- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx index b4281d27bb2eb..77d752548bb5b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx @@ -28,8 +28,8 @@ cdef class GEMMTermComputer64: """ def __init__(self, - DTYPE_t[:, ::1] X, - DTYPE_t[:, ::1] Y, + const DTYPE_t[:, ::1] X, + const DTYPE_t[:, ::1] Y, ITYPE_t effective_n_threads, ITYPE_t chunks_n_threads, ITYPE_t dist_middle_terms_chunks_size, diff --git a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py index 931790abd4425..1505dd8df2469 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py @@ -21,6 +21,7 @@ assert_array_equal, assert_allclose, assert_radius_neighborhood_results_equality, + create_memmap_backed_data, ) # Common supported metric between scipy.spatial.distance.cdist @@ -648,7 +649,7 @@ def test_chunk_size_agnosticism( n_features=100, dtype=np.float64, ): - # Results should not depend on the chunk size + # Results must not depend on the chunk size rng = np.random.RandomState(global_random_seed) spread = 100 X = rng.rand(n_samples, n_features).astype(dtype) * spread @@ -699,7 +700,7 @@ def test_n_threads_agnosticism( n_features=100, dtype=np.float64, ): - # Results should not depend on the number of threads + # Results must not depend on the number of threads rng = np.random.RandomState(global_random_seed) spread = 100 X = rng.rand(n_samples, n_features).astype(dtype) * spread @@ -934,6 +935,58 @@ def test_pairwise_distances_radius_neighbors( ) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], +) +@pytest.mark.parametrize("metric", ["manhattan", "euclidean"]) +def test_memmap_backed_data( + global_random_seed, + metric, + PairwiseDistancesReduction, + n_samples=512, + n_features=100, + dtype=np.float64, +): + # Results must not depend on the datasets writability + rng = np.random.RandomState(global_random_seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + # Create read only datasets + X_mm, Y_mm = create_memmap_backed_data([X, Y]) + + if PairwiseDistancesReduction is PairwiseDistancesArgKmin: + parameter = 10 + check_parameters = {} + else: + # Scaling the radius slightly with the numbers of dimensions + radius = 10 ** np.log(n_features) + parameter = radius + check_parameters = {"radius": radius} + + ref_dist, ref_indices = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + return_distance=True, + ) + + dist_mm, indices_mm = PairwiseDistancesReduction.compute( + X_mm, + Y_mm, + parameter, + metric=metric, + return_distance=True, + ) + + ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( + ref_dist, dist_mm, ref_indices, indices_mm, **check_parameters + ) + + @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) @pytest.mark.parametrize("num_threads", [1, 2, 8]) From 8f791b6e87d733c47923882cfb16d80e8becfe40 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 22 Jun 2022 17:12:28 +0200 Subject: [PATCH 14/19] Apply review comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From: https://github.com/jjerphan/scikit-learn/pull/14#pullrequestreview-1004501769 Co-authored-by: Jérémie du Boisberranger --- .../_radius_neighborhood.pyx | 2 +- .../test_pairwise_distances_reduction.py | 20 ++++++++++++++++++- sklearn/neighbors/tests/test_neighbors.py | 4 +++- sklearn/utils/_testing.py | 19 ------------------ 4 files changed, 23 insertions(+), 22 deletions(-) rename sklearn/metrics/{_pairwise_distances_reduction => }/tests/test_pairwise_distances_reduction.py (98%) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx index d14fab614a3d7..db2c22e89d06d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx @@ -56,7 +56,7 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" + """64bit implementation of PairwiseDistancesRadiusNeighborhood.""" @classmethod def compute( diff --git a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py similarity index 98% rename from sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py rename to sklearn/metrics/tests/test_pairwise_distances_reduction.py index 1505dd8df2469..d13498a346fb0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -20,7 +20,6 @@ from sklearn.utils._testing import ( assert_array_equal, assert_allclose, - assert_radius_neighborhood_results_equality, create_memmap_backed_data, ) @@ -179,6 +178,25 @@ def assert_argkmin_results_quasi_equality( ), msg +def assert_radius_neighborhood_results_equality( + ref_dist, dist, ref_indices, indices, radius +): + # We get arrays of arrays and we need to check for individual pairs + for i in range(ref_dist.shape[0]): + assert (ref_dist[i] <= radius).all() + assert_array_equal( + ref_indices[i], + indices[i], + err_msg=f"Query vector #{i} has different neighbors' indices", + ) + assert_allclose( + ref_dist[i], + dist[i], + err_msg=f"Query vector #{i} has different neighbors' distances", + rtol=1e-7, + ) + + def assert_radius_neighborhood_results_quasi_equality( ref_dist, dist, diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index da548e4a9f046..faffa8bf85265 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -28,6 +28,9 @@ from sklearn.exceptions import NotFittedError from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS +from sklearn.metrics.tests.test_pairwise_distances_reduction import ( + assert_radius_neighborhood_results_equality, +) from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from sklearn.neighbors import ( @@ -44,7 +47,6 @@ from sklearn.utils._testing import ( assert_allclose, assert_array_equal, - assert_radius_neighborhood_results_equality, ) from sklearn.utils._testing import ignore_warnings from sklearn.utils.validation import check_random_state diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 4158f7fa65c83..8a94b1f31abee 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -1151,22 +1151,3 @@ def transform(self, X, y=None): def fit_transform(self, X, y=None): return self.fit(X, y).transform(X, y) - - -def assert_radius_neighborhood_results_equality( - ref_dist, dist, ref_indices, indices, radius -): - # We get arrays of arrays and we need to check for individual pairs - for i in range(ref_dist.shape[0]): - assert (ref_dist[i] <= radius).all() - assert_array_equal( - ref_indices[i], - indices[i], - err_msg=f"Query vector #{i} has different neighbors' indices", - ) - assert_allclose( - ref_dist[i], - dist[i], - err_msg=f"Query vector #{i} has different neighbors' distances", - rtol=1e-7, - ) From 229758c24cd3adff83f0ecf6c40dfd23a3358585 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 23 Jun 2022 13:07:58 +0200 Subject: [PATCH 15/19] Apply review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- .../_pairwise_distances_reduction/__init__.py | 4 +- .../_dispatcher.py | 46 ++++++++++++++++++- .../_pairwise_distances_reduction/setup.py | 2 - .../tests/__init__.py | 0 .../test_pairwise_distances_reduction.py | 8 +++- 5 files changed, 52 insertions(+), 8 deletions(-) delete mode 100644 sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 89778ae9fc5e0..92dea6810110c 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -90,12 +90,12 @@ PairwiseDistancesReduction, PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood, - _sqeuclidean_row_norms, + sqeuclidean_row_norms, ) __all__ = [ "PairwiseDistancesReduction", "PairwiseDistancesArgKmin", "PairwiseDistancesRadiusNeighborhood", - "_sqeuclidean_row_norms", + "sqeuclidean_row_norms", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index c0193b33dec2b..853989aa3f504 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -1,3 +1,5 @@ +from abc import abstractmethod + import numpy as np from typing import List @@ -11,8 +13,22 @@ from ... import get_config -def _sqeuclidean_row_norms(X, num_threads): - """Compute the squared euclidean norm of the rows of X in parallel.""" +def sqeuclidean_row_norms(X, num_threads): + """Compute the squared euclidean norm of the rows of X in parallel. + + Parameters + ---------- + X : ndarray of shape (n_samples, n_features) + Input data. Must be c-contiguous. + + num_threads : int + The number of OpenMP threads to use. + + Returns + ------- + sqeuclidean_row_norms : ndarray of shape (n_samples,) + Arrays containing the squared euclidean norm of each row of X. + """ if X.dtype == np.float64: return _sqeuclidean_row_norms64(X, num_threads) raise ValueError( @@ -72,6 +88,32 @@ def is_usable_for(cls, X, Y, metric) -> bool: and metric in cls.valid_metrics() ) + @classmethod + @abstractmethod + def compute( + cls, + X, + Y, + **kwargs, + ): + """Compute the reduction. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + **kwargs : additional parameters for the reduction + + Notes + ----- + This method is an abstract class method: it has to be implemented + for all subclasses. + """ + class PairwiseDistancesArgKmin(PairwiseDistancesReduction): """Compute the argkmin of row vectors of X on the ones of Y. diff --git a/sklearn/metrics/_pairwise_distances_reduction/setup.py b/sklearn/metrics/_pairwise_distances_reduction/setup.py index ddde0717835da..0d8c2c8ce33de 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/setup.py +++ b/sklearn/metrics/_pairwise_distances_reduction/setup.py @@ -29,8 +29,6 @@ def configuration(parent_package="", top_path=None): extra_compile_args=["-std=c++11"], ) - config.add_subpackage("tests") - return config diff --git a/sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index d13498a346fb0..e0d2beb820506 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -12,7 +12,7 @@ PairwiseDistancesReduction, PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood, - _sqeuclidean_row_norms, + sqeuclidean_row_norms, ) from sklearn.metrics import euclidean_distances @@ -1020,6 +1020,10 @@ def test_sqeuclidean_row_norms( X = rng.rand(n_samples, n_features).astype(dtype) * spread sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 - sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) + sq_row_norm = np.asarray(sqeuclidean_row_norms(X, num_threads=num_threads)) assert_allclose(sq_row_norm_reference, sq_row_norm) + + with pytest.raises(ValueError): + X = np.asfortranarray(X) + sqeuclidean_row_norms(X, num_threads=num_threads) From 303aac6bb8a8eae1cd10d40f76a746ffe9609a1c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 23 Jun 2022 16:03:33 +0200 Subject: [PATCH 16/19] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/metrics/_pairwise_distances_reduction/__init__.py | 2 +- sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 92dea6810110c..d420060ca78df 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -42,7 +42,7 @@ # Legend: # # A ---⊳ B: A inherits from B -# A ---x B: A dispatches on B +# A ---x B: A dispatches to B # # # (base dispatcher) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 853989aa3f504..a79fde694a9ed 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -79,7 +79,7 @@ def is_usable_for(cls, X, Y, metric) -> bool: ------- True if the PairwiseDistancesReduction can be used, else False. """ - dtypes_validity = X.dtype == Y.dtype and Y.dtype == np.float64 + dtypes_validity = X.dtype == Y.dtype == np.float64 return ( get_config().get("enable_cython_pairwise_dist", True) and not issparse(X) From 94c76fbaa7f9169abba8cf12eb6c42c84bc2d550 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 23 Jun 2022 21:09:08 +0200 Subject: [PATCH 17/19] TST Remove useless globam random seed Co-authored-by: Thomas J. Fan --- sklearn/metrics/tests/test_pairwise_distances_reduction.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index e0d2beb820506..0b9c6e6aad196 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -959,7 +959,6 @@ def test_pairwise_distances_radius_neighbors( ) @pytest.mark.parametrize("metric", ["manhattan", "euclidean"]) def test_memmap_backed_data( - global_random_seed, metric, PairwiseDistancesReduction, n_samples=512, @@ -967,7 +966,7 @@ def test_memmap_backed_data( dtype=np.float64, ): # Results must not depend on the datasets writability - rng = np.random.RandomState(global_random_seed) + rng = np.random.RandomState(0) spread = 100 X = rng.rand(n_samples, n_features).astype(dtype) * spread Y = rng.rand(n_samples, n_features).astype(dtype) * spread From dc26da073615b309fd953045d52ddef9e6effa74 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 23 Jun 2022 21:10:11 +0200 Subject: [PATCH 18/19] CI Trigger some tests on supporting global random seed [all random seeds] test_chunk_size_agnosticism test_n_threads_agnosticism test_strategies_consistency test_pairwise_distances_argkmin test_pairwise_distances_radius_neighbors test_sqeuclidean_row_norms Co-authored-by: Thomas J. Fan From b089de367b50a0e4c905745d58e3eb21604faeeb Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 23 Jun 2022 20:13:12 -0400 Subject: [PATCH 19/19] CI Trigger [all random seeds] test_chunk_size_agnosticism test_n_threads_agnosticism test_strategies_consistency test_pairwise_distances_argkmin test_pairwise_distances_radius_neighbors test_sqeuclidean_row_norms