diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp index 2bbe9e53518b3..91936d6922ef4 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp @@ -6,7 +6,6 @@ cnp.import_array() {{for name_suffix in ['64', '32']}} from ._base cimport BaseDistancesReduction{{name_suffix}} -from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): """float{{name_suffix}} implementation of the ArgKmin.""" @@ -21,14 +20,4 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): DTYPE_t ** heaps_r_distances_chunks ITYPE_t ** heaps_indices_chunks - -cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): - """EuclideanDistance-specialisation of ArgKmin{{name_suffix}}.""" - cdef: - MiddleTermComputer{{name_suffix}} middle_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index b8afe5c3cd5f8..79f7c20c2153b 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -10,11 +10,10 @@ from ...utils._sorting cimport simultaneous_sort from ...utils._typedefs cimport ITYPE_t, DTYPE_t import numpy as np -import warnings from numbers import Integral -from scipy.sparse import issparse -from ...utils import check_array, check_scalar, _in_unstable_openblas_configuration +# TODO: reintroduce _in_unstable... warning +from ...utils import check_scalar from ...utils.fixes import threadpool_limits from ...utils._typedefs import ITYPE, DTYPE @@ -23,14 +22,7 @@ cnp.import_array() {{for name_suffix in ['64', '32']}} -from ._base cimport ( - BaseDistancesReduction{{name_suffix}}, - _sqeuclidean_row_norms{{name_suffix}}, -) - -from ._datasets_pair cimport DatasetsPair{{name_suffix}} - -from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} +from ._base cimport BaseDistancesReduction{{name_suffix}} cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): @@ -61,36 +53,14 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): No instance should directly be created outside of this class method. """ - if ( - metric in ("euclidean", "sqeuclidean") - and not (issparse(X) ^ issparse(Y)) # "^" is the XOR operator - ): - # Specialized implementation of ArgKmin for the Euclidean distance - # for the dense-dense and sparse-sparse cases. - # This implementation computes the distances by chunk using - # a decomposition of the Squared Euclidean distance. - # This specialisation has an improved arithmetic intensity for both - # the dense and sparse settings, allowing in most case speed-ups of - # several orders of magnitude compared to the generic ArgKmin - # implementation. - # For more information see MiddleTermComputer. - use_squared_distances = metric == "sqeuclidean" - pda = EuclideanArgKmin{{name_suffix}}( - X=X, Y=Y, k=k, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - strategy=strategy, - metric_kwargs=metric_kwargs, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = ArgKmin{{name_suffix}}( - datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) + pda = ArgKmin{{name_suffix}}( + X, Y, + chunk_size=chunk_size, + strategy=strategy, + k=k, + metric=metric, + metric_kwargs=metric_kwargs, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -104,15 +74,19 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): def __init__( self, - DatasetsPair{{name_suffix}} datasets_pair, + X, Y, chunk_size=None, strategy=None, ITYPE_t k=1, + metric="euclidean", + metric_kwargs=None, ): super().__init__( - datasets_pair=datasets_pair, + X, Y, chunk_size=chunk_size, strategy=strategy, + metric=metric, + metric_kwargs=metric_kwargs ) self.k = check_scalar(k, "k", Integral, min_val=1) @@ -166,10 +140,23 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): values=heaps_r_distances + i * self.k, indices=heaps_indices + i * self.k, size=self.k, - val=self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), + val=self.datasets_pair.surrogate_dist( + X_start, + Y_start, + i, + j, + n_samples_Y, + thread_num, + ), val_idx=Y_start + j, ) + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num + ) nogil: + self.datasets_pair._parallel_on_X_parallel_init(thread_num) + cdef void _parallel_on_X_init_chunk( self, ITYPE_t thread_num, @@ -180,6 +167,20 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): # thread's heaps pointer to the proper position on the main heaps. self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] + self.datasets_pair._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + self.datasets_pair._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) @final cdef void _parallel_on_X_prange_iter_finalize( @@ -224,6 +225,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): self.heaps_indices_chunks[thread_num] = malloc( heaps_size * sizeof(ITYPE_t) ) + self.datasets_pair._parallel_on_Y_init() cdef void _parallel_on_Y_parallel_init( self, @@ -235,6 +237,23 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): for idx in range(self.X_n_samples_chunk * self.k): self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX self.heaps_indices_chunks[thread_num][idx] = -1 + self.datasets_pair._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + self.datasets_pair._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) @final cdef void _parallel_on_Y_synchronize( @@ -289,13 +308,17 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): cdef: ITYPE_t i, j DTYPE_t[:, ::1] distances = self.argkmin_distances - for i in prange(self.n_samples_X, schedule='static', nogil=True, - num_threads=self.effective_n_threads): - for j in range(self.k): - distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against potential -0., causing nan production. - max(distances[i, j], 0.) - ) + bint need_to_compute_exact_dist = ( + self.datasets_pair.need_to_compute_exact_dist() + ) + if need_to_compute_exact_dist: + for i in prange(self.n_samples_X, schedule='static', nogil=True, + num_threads=self.effective_n_threads): + for j in range(self.k): + distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against potential -0., causing nan production. + max(distances[i, j], 0.) + ) def _finalize_results(self, bint return_distance=False): if return_distance: @@ -310,206 +333,4 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): return np.asarray(self.argkmin_indices) - -cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): - """EuclideanDistance-specialisation of ArgKmin{{name_suffix}}.""" - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (ArgKmin{{name_suffix}}.is_usable_for(X, Y, metric) and - not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - ITYPE_t k, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - metric_kwargs=None, - ): - if ( - isinstance(metric_kwargs, dict) and - (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (EuclideanArgKmin64) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), - chunk_size=chunk_size, - strategy=strategy, - k=k, - ) - cdef: - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for( - X, - Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = check_array( - metric_kwargs.pop("Y_norm_squared"), - ensure_2d=False, - input_name="Y_norm_squared", - dtype=np.float64, - ) - else: - self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( - Y, - self.effective_n_threads, - ) - - if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: - self.X_norm_squared = check_array( - metric_kwargs.pop("X_norm_squared"), - ensure_2d=False, - input_name="X_norm_squared", - dtype=np.float64, - ) - else: - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms{{name_suffix}}( - X, - self.effective_n_threads, - ) - ) - - self.use_squared_distances = use_squared_distances - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - ArgKmin{{name_suffix}}.compute_exact_distances(self) - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) - self.middle_term_computer._parallel_on_X_parallel_init(thread_num) - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_Y_init(self) - self.middle_term_computer._parallel_on_Y_init() - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - ArgKmin{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t sqeuclidean_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms( - X_start, X_end, Y_start, Y_end, thread_num - ) - DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num] - - # Pushing the distance and their associated indices on heaps - # which keep tracks of the argkmin. - for i in range(n_X): - for j in range(n_Y): - sqeuclidean_dist_i_j = ( - self.X_norm_squared[i + X_start] + - dist_middle_terms[i * n_Y + j] + - self.Y_norm_squared[j + Y_start] - ) - - # Catastrophic cancellation might cause -0. to be present, - # e.g. when computing d(x_i, y_i) when X is Y. - sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j) - - heap_push( - values=heaps_r_distances + i * self.k, - indices=heaps_indices + i * self.k, - size=self.k, - val=sqeuclidean_dist_i_j, - val_idx=j + Y_start, - ) - {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index 1b2a8a31fb679..21454d9e79f86 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -26,6 +26,7 @@ from ...utils._typedefs cimport ITYPE_t, DTYPE_t import numpy as np +import warnings from scipy.sparse import issparse from numbers import Integral from sklearn import get_config @@ -131,7 +132,7 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}( ): if issparse(X): # TODO: remove this instruction which is a cast in the float32 case - # by moving squared row norms computations in MiddleTermComputer. + # by moving squared row norms computations in MiddleTermComputer. X_data = np.asarray(X.data, dtype=DTYPE) X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) return _sqeuclidean_row_norms64_sparse(X_data, X_indptr, num_threads) @@ -151,9 +152,11 @@ cdef class BaseDistancesReduction{{name_suffix}}: def __init__( self, - DatasetsPair{{name_suffix}} datasets_pair, + X, Y, chunk_size=None, strategy=None, + metric="euclidean", + metric_kwargs=None, ): cdef: ITYPE_t X_n_full_chunks, Y_n_full_chunks @@ -165,9 +168,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: self.effective_n_threads = _openmp_effective_n_threads() - self.datasets_pair = datasets_pair - - self.n_samples_X = datasets_pair.n_samples_X() + self.n_samples_X = X.shape[0] self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk @@ -178,7 +179,7 @@ cdef class BaseDistancesReduction{{name_suffix}}: else: self.X_n_samples_last_chunk = self.X_n_samples_chunk - self.n_samples_Y = datasets_pair.n_samples_Y() + self.n_samples_Y = Y.shape[0] self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk @@ -223,6 +224,52 @@ cdef class BaseDistancesReduction{{name_suffix}}: self.effective_n_threads, ) + if metric in ["euclidean", "sqeuclidean"]: + + X_norm_squared = ( + metric_kwargs.pop("X_norm_squared", None) + if metric_kwargs is not None + else None + ) + + Y_norm_squared = ( + metric_kwargs.pop("Y_norm_squared", None) + if metric_kwargs is not None + else None + ) + + euclidean_kwargs = dict( + effective_n_threads=self.effective_n_threads, + chunks_n_threads=self.chunks_n_threads, + dist_middle_terms_chunks_size=(self.Y_n_samples_chunk * self.X_n_samples_chunk), + X_norm_squared=X_norm_squared, + Y_norm_squared=Y_norm_squared, + Y_n_chunks=self.Y_n_chunks, + Y_n_samples_chunk=self.Y_n_samples_chunk, + Y_n_samples_last_chunk=self.Y_n_samples_last_chunk, + chunk_size=self.chunk_size, + ) + + if isinstance(metric_kwargs, dict) and metric_kwargs.keys(): + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " + f"usable for this case (euclidean) and will be ignored.", + UserWarning, + stacklevel=3, + ) + metric_kwargs = None + + else: + euclidean_kwargs = None + + self.datasets_pair = DatasetsPair{{name_suffix}}.get_for( + X, Y, + metric, + metric_kwargs, + euclidean_kwargs, + ) + + @final cdef void _parallel_on_X(self) nogil: """Perform computation and reduction in parallel on chunks of X. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp index e220f730e7529..3de1bbc546a43 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -13,12 +13,17 @@ implementation_specific_values = [ }} cimport numpy as cnp +from libcpp.vector cimport vector from ...utils._typedefs cimport DTYPE_t, ITYPE_t, SPARSE_INDEX_TYPE_t from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32 {{for name_suffix, DistanceMetric, INPUT_DTYPE_t in implementation_specific_values}} +from ._middle_term_computer cimport ( + DenseDenseMiddleTermComputer{{name_suffix}}, + SparseSparseMiddleTermComputer{{name_suffix}}, +) cdef class DatasetsPair{{name_suffix}}: cdef: @@ -31,7 +36,53 @@ cdef class DatasetsPair{{name_suffix}}: cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil + cdef DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + cdef void _parallel_on_Y_init(self) nogil + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + + cdef bint need_to_compute_exact_dist(self) nogil cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): @@ -40,6 +91,26 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): const {{INPUT_DTYPE_t}}[:, ::1] Y +cdef class EuclideanDenseDenseDatasetsPair{{name_suffix}}(DenseDenseDatasetsPair{{name_suffix}}): + cdef: + ITYPE_t Y_n_chunks + ITYPE_t Y_n_samples_chunk + ITYPE_t Y_n_samples_last_chunk + ITYPE_t Y_n_samples_fixed_size + + ITYPE_t chunks_n_threads + + bint use_squared_distances + + vector[vector[DTYPE_t]] dist_middle_terms_chunks + ITYPE_t dist_middle_terms_chunks_size + + DenseDenseMiddleTermComputer{{name_suffix}} middle_term_computer + + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): cdef: const {{INPUT_DTYPE_t}}[:] X_data @@ -51,6 +122,26 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): const SPARSE_INDEX_TYPE_t[:] Y_indptr +cdef class EuclideanSparseSparseDatasetsPair{{name_suffix}}(SparseSparseDatasetsPair{{name_suffix}}): + cdef: + ITYPE_t Y_n_chunks + ITYPE_t Y_n_samples_chunk + ITYPE_t Y_n_samples_last_chunk + ITYPE_t Y_n_samples_fixed_size + + ITYPE_t chunks_n_threads + + bint use_squared_distances + + vector[vector[DTYPE_t]] dist_middle_terms_chunks + ITYPE_t dist_middle_terms_chunks_size + + SparseSparseMiddleTermComputer{{name_suffix}} middle_term_computer + + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): cdef: const {{INPUT_DTYPE_t}}[:] X_data diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 78857341f9c97..3d908f7439220 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -14,18 +14,35 @@ implementation_specific_values = [ }} import numpy as np cimport numpy as cnp +import warnings + +from libcpp.vector cimport vector from cython cimport final from ...utils._typedefs cimport DTYPE_t, ITYPE_t from ...metrics._dist_metrics cimport DistanceMetric +# TODO: change for `libcpp.algorithm.fill` once Cython 3 is used +# Introduction in Cython: +# +# https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L50 #noqa +cdef extern from "" namespace "std" nogil: + void fill[Iter, T](Iter first, Iter last, const T& value) except + #noqa + from scipy.sparse import issparse, csr_matrix +from ...utils import check_array from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE cnp.import_array() {{for name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} +from ._middle_term_computer cimport ( + DenseDenseMiddleTermComputer{{name_suffix}}, + SparseSparseMiddleTermComputer{{name_suffix}}, +) +from ._base cimport _sqeuclidean_row_norms{{name_suffix}} + cdef class DatasetsPair{{name_suffix}}: """Abstract class which wraps a pair of datasets (X, Y). @@ -62,6 +79,7 @@ cdef class DatasetsPair{{name_suffix}}: Y, str metric="euclidean", dict metric_kwargs=None, + dict euclidean_kwargs=None, ) -> DatasetsPair{{name_suffix}}: """Return the DatasetsPair implementation for the given arguments. @@ -91,13 +109,16 @@ cdef class DatasetsPair{{name_suffix}}: datasets_pair: DatasetsPair{{name_suffix}} The suited DatasetsPair{{name_suffix}} implementation. """ - # Y_norm_squared might be propagated down to DatasetsPairs - # via metrics_kwargs when the Euclidean specialisations - # can't be used. To prevent Y_norm_squared to be passed - # down to DistanceMetrics (whose constructors would raise - # a RuntimeError), we pop it here. - if metric_kwargs is not None: - metric_kwargs.pop("Y_norm_squared", None) + if ( + isinstance(metric_kwargs, dict) and + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) + ): + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " + f"usable for this case (EuclideanRadiusNeighbors64) and will be ignored.", + UserWarning, + stacklevel=3, + ) cdef: {{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric( metric, @@ -111,11 +132,57 @@ cdef class DatasetsPair{{name_suffix}}: X_is_sparse = issparse(X) Y_is_sparse = issparse(Y) + use_euclidean_specialization = ( + metric in ("euclidean", "sqeuclidean") + and not (issparse(X) ^ issparse(Y)) # "^" is the XOR operator + ) + if use_euclidean_specialization: + use_squared_distances = metric == "sqeuclidean" + if not X_is_sparse and not Y_is_sparse: + if use_euclidean_specialization: + # Specialized implementation of {ArgKmin, RadiusNeighbors} + # for the Euclidean distance for the dense-dense case. + # This implementation computes the distances by chunk using + # a decomposition of the Squared Euclidean distance. + # This specialisation has an improved arithmetic intensity for both + # the dense and sparse settings, allowing in most case speed-ups of + # several orders of magnitude compared to the generic {ArgKmin, RadiusNeighbors} + # implementation. + # For more information see MiddleTermComputer. + return EuclideanDenseDenseDatasetsPair{{name_suffix}}( + X, Y, + distance_metric, + use_squared_distances, + **(euclidean_kwargs or {}), + ) + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) if X_is_sparse and Y_is_sparse: - return SparseSparseDatasetsPair{{name_suffix}}(X, Y, distance_metric) + if use_euclidean_specialization: + # Specialized implementation of {ArgKmin, RadiusNeighbors} + # for the Euclidean distance for the sparse-sparse case. + # This implementation computes the distances by chunk using + # a decomposition of the Squared Euclidean distance. + # This specialisation has an improved arithmetic intensity for both + # the dense and sparse settings, allowing in most case speed-ups of + # several orders of magnitude compared to the generic {ArgKmin, RadiusNeighbors} + # implementation. + # For more information see MiddleTermComputer. + return EuclideanSparseSparseDatasetsPair{{name_suffix}}( + X, Y, + distance_metric, + use_squared_distances, + **(euclidean_kwargs or {}), + ) + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + return SparseSparseDatasetsPair{{name_suffix}}( + X, Y, + distance_metric, + ) if X_is_sparse and not Y_is_sparse: return SparseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) @@ -148,7 +215,14 @@ cdef class DatasetsPair{{name_suffix}}: # TODO: add "with gil: raise" here when supporting Cython 3.0 return -999 - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + cdef DTYPE_t surrogate_dist(self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: return self.dist(i, j) cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: @@ -157,7 +231,61 @@ cdef class DatasetsPair{{name_suffix}}: # TODO: add "with gil: raise" here when supporting Cython 3.0 return -1 -@final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef void _parallel_on_Y_init(self) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + # This method must be overwritten in the Euclidean specialization only. + return + + cdef bint need_to_compute_exact_dist(self) nogil: + # This method must be overwritten in the Euclidean specialization only. + return True + cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): """Compute distances between row vectors of two arrays. @@ -194,8 +322,20 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): return self.Y.shape[0] @final - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: - return self.distance_metric.rdist(&self.X[i, 0], &self.Y[j, 0], self.n_features) + cdef DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: + return self.distance_metric.rdist( + &self.X[X_start+i, 0], + &self.Y[Y_start+j, 0], + self.n_features, + ) @final cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: @@ -203,6 +343,240 @@ cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): @final +cdef class EuclideanDenseDenseDatasetsPair{{name_suffix}}(DenseDenseDatasetsPair{{name_suffix}}): + + def __init__( + self, + X, + Y, + {{DistanceMetric}} distance_metric, + bint use_squared_distances, + ITYPE_t Y_n_chunks, + ITYPE_t Y_n_samples_chunk, + ITYPE_t Y_n_samples_last_chunk, + ITYPE_t effective_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ITYPE_t chunks_n_threads, + **euclidean_kwargs, + ): + super().__init__(X, Y, distance_metric) + + # Used to compute the surrogate distance + self.Y_n_chunks = Y_n_chunks + self.Y_n_samples_chunk = Y_n_samples_chunk + self.Y_n_samples_last_chunk = Y_n_samples_last_chunk + + self.chunks_n_threads = chunks_n_threads + + # The number of samples belonging to the all-but-last chunks of fixed width. + self.Y_n_samples_fixed_size = (Y_n_chunks - 1) * Y_n_samples_chunk + + self.use_squared_distances = use_squared_distances + + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](effective_n_threads) + self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size + + Y_norm_squared = (euclidean_kwargs or {}).pop("Y_norm_squared") + if Y_norm_squared is not None: + self.Y_norm_squared = check_array( + Y_norm_squared, + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64, + ) + else: + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( + Y, + effective_n_threads, + ) + + X_norm_squared = (euclidean_kwargs or {}).pop("X_norm_squared") + if X_norm_squared is not None: + self.X_norm_squared = check_array( + X_norm_squared, + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64, + ) + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms{{name_suffix}}( + X, + effective_n_threads, + ) + ) + + self.middle_term_computer = DenseDenseMiddleTermComputer{{name_suffix}}( + self.X, + self.Y, + n_features=X.shape[1], + effective_n_threads=effective_n_threads, + **(euclidean_kwargs or {}), + ) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_parallel_init( + self, thread_num + ) + self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + #self.middle_term_computer._parallel_on_X_parallel_init(thread_num) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_init_chunk( + self, thread_num, X_start, X_end + ) + self.middle_term_computer._parallel_on_X_init_chunk( + thread_num, + X_start, + X_end, + ) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j, k, l + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + + DatasetsPair{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, X_start, X_end, Y_start, Y_end, thread_num + ) + self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) + self.middle_term_computer._compute_dist_middle_terms( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + dist_middle_terms, + ) + for i in range(n_X): + k = i * n_Y + for j in range(n_Y): + l = k + j + # Catastrophic cancellation might cause -0. to be present, + # e.g. when computing d(x_i, y_i) when X is Y. + dist_middle_terms[l] = max(0, + dist_middle_terms[l] + + self.X_norm_squared[X_start + i] + + self.Y_norm_squared[Y_start + j] + ) + + @final + cdef void _parallel_on_Y_init(self) nogil: + DatasetsPair{{name_suffix}}._parallel_on_Y_init(self) + for thread_num in range(self.chunks_n_threads): + self.dist_middle_terms_chunks[thread_num].resize( + self.dist_middle_terms_chunks_size + ) + self.middle_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_Y_parallel_init( + self, thread_num, X_start, X_end + ) + self.middle_term_computer._parallel_on_Y_parallel_init( + thread_num, + X_start, + X_end, + ) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j, k, l + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + DatasetsPair{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, X_start, X_end, Y_start, Y_end, thread_num + ) + self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) + self.middle_term_computer._compute_dist_middle_terms( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + dist_middle_terms, + ) + for i in range(n_X): + k = i * n_Y + for j in range(n_Y): + l = k + j + # Catastrophic cancellation might cause -0. to be present, + # e.g. when computing d(x_i, y_i) when X is Y. + dist_middle_terms[l] = max(0, + dist_middle_terms[l] + + self.X_norm_squared[X_start + i] + + self.Y_norm_squared[Y_start + j] + ) + + @final + cdef inline DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: + return self.dist_middle_terms_chunks[thread_num][i * n_Y + j] + + @final + cdef bint need_to_compute_exact_dist(self) nogil: + # At that point `self.argkmin_distances` stores surrogate distances, + # which are in this case square euclidean distances. + # We only need to convert those values for the `metric='euclidean'` case. + return not self.use_squared_distances + + cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): """Compute distances between vectors of two CSR matrices. @@ -234,7 +608,17 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): return self.Y_indptr.shape[0] - 1 @final - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + cdef DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num + ) nogil: + i += X_start + j += Y_start return self.distance_metric.rdist_csr( x1_data=&self.X_data[0], x1_indices=self.X_indices, @@ -262,6 +646,236 @@ cdef class SparseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): ) +@final +cdef class EuclideanSparseSparseDatasetsPair{{name_suffix}}(SparseSparseDatasetsPair{{name_suffix}}): + + def __init__( + self, + X, + Y, + {{DistanceMetric}} distance_metric, + bint use_squared_distances, + ITYPE_t Y_n_chunks, + ITYPE_t Y_n_samples_chunk, + ITYPE_t Y_n_samples_last_chunk, + ITYPE_t effective_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ITYPE_t chunks_n_threads, + **euclidean_kwargs, + ): + super().__init__(X, Y, distance_metric) + + # Used to compute the surrogate distance. + self.Y_n_chunks = Y_n_chunks + self.Y_n_samples_chunk = Y_n_samples_chunk + self.Y_n_samples_last_chunk = Y_n_samples_last_chunk + + self.chunks_n_threads = chunks_n_threads + + # The number of samples belonging to the all-but-last chunks of fixed width. + self.Y_n_samples_fixed_size = (Y_n_chunks - 1) * Y_n_samples_chunk + + self.use_squared_distances = use_squared_distances + + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](effective_n_threads) + self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size + + euclidean_kwargs = euclidean_kwargs or {} + Y_norm_squared = euclidean_kwargs.pop("Y_norm_squared", None) + + if Y_norm_squared is not None: + self.Y_norm_squared = check_array( + Y_norm_squared, + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64, + ) + else: + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( + Y, + effective_n_threads, + ) + + X_norm_squared = euclidean_kwargs.pop("X_norm_squared", None) + + if X_norm_squared is not None: + self.X_norm_squared = check_array( + X_norm_squared, + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64, + ) + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms{{name_suffix}}( + X, + effective_n_threads, + ) + ) + + self.middle_term_computer = SparseSparseMiddleTermComputer{{name_suffix}}( + self.X_data, + self.X_indices, + self.X_indptr, + self.Y_data, + self.Y_indices, + self.Y_indptr, + n_features=X.shape[1], + effective_n_threads=effective_n_threads, + **(euclidean_kwargs), + ) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_parallel_init( + self, thread_num + ) + #self.middle_term_computer._parallel_on_X_parallel_init(thread_num) + self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_init_chunk( + self, thread_num, X_start, X_end + ) + self.middle_term_computer._parallel_on_X_init_chunk( + thread_num, + X_start, + X_end, + ) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, X_start, X_end, Y_start, Y_end, thread_num + ) + fill( + self.dist_middle_terms_chunks[thread_num].begin(), + self.dist_middle_terms_chunks[thread_num].end(), + 0.0, + ) + self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) + self.middle_term_computer._compute_dist_middle_terms( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + self.dist_middle_terms_chunks[thread_num].data(), + ) + + @final + cdef void _parallel_on_Y_init(self) nogil: + DatasetsPair{{name_suffix}}._parallel_on_Y_init(self) + for thread_num in range(self.chunks_n_threads): + self.dist_middle_terms_chunks[thread_num].resize( + self.dist_middle_terms_chunks_size + ) + self.middle_term_computer._parallel_on_Y_init() + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_Y_parallel_init( + self, thread_num, X_start, X_end + ) + self.middle_term_computer._parallel_on_Y_parallel_init( + thread_num, + X_start, + X_end, + ) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + DatasetsPair{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, X_start, X_end, Y_start, Y_end, thread_num + ) + fill( + self.dist_middle_terms_chunks[thread_num].begin(), + self.dist_middle_terms_chunks[thread_num].end(), + 0.0, + ) + self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) + self.middle_term_computer._compute_dist_middle_terms( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + self.dist_middle_terms_chunks[thread_num].data(), + ) + + @final + cdef DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: + # Y_chunk_size is constant except for the last chunk that can be smaller. + # cdef: + # DTYPE_t dist_middle_term = ( + # self.middle_term_computer.dist_middle_terms_chunks[thread_num][i * n_Y + j] + # ) + # Catastrophic cancellation might cause -0. to be present, + # e.g. when computing d(x_i, y_i) when X is Y. + return max(0., + self.X_norm_squared[i + X_start] + + self.dist_middle_terms_chunks[thread_num][i * n_Y + j] + + self.Y_norm_squared[j + Y_start] + ) + + @final + cdef bint need_to_compute_exact_dist(self) nogil: + # At that point `self.argkmin_distances` stores surrogate distances, + # which are in this case square euclidean distances. + # We only need to convert those values for the `metric='euclidean'` case. + return not self.use_squared_distances + + @final cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): """Compute distances between vectors of a CSR matrix and a dense array. @@ -327,7 +941,16 @@ cdef class SparseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): return self.n_Y @final - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + cdef DTYPE_t surrogate_dist(self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: + i += X_start + j += Y_start return self.distance_metric.rdist_csr( x1_data=&self.X_data[0], x1_indices=self.X_indices, @@ -393,9 +1016,17 @@ cdef class DenseSparseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): return self.datasets_pair.n_samples_X() @final - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + cdef DTYPE_t surrogate_dist( + self, + ITYPE_t X_start, + ITYPE_t Y_start, + ITYPE_t i, + ITYPE_t j, + ITYPE_t n_Y, + ITYPE_t thread_num, + ) nogil: # Swapping arguments on the same interface - return self.datasets_pair.surrogate_dist(j, i) + return self.datasets_pair.surrogate_dist(Y_start, X_start, j, i, n_Y, thread_num) @final cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp index e6ef5de2727b5..502b3d12c40fe 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp @@ -41,13 +41,11 @@ cdef void _middle_term_sparse_sparse_64( cdef class MiddleTermComputer{{name_suffix}}: cdef: ITYPE_t effective_n_threads - ITYPE_t chunks_n_threads - ITYPE_t dist_middle_terms_chunks_size ITYPE_t n_features ITYPE_t chunk_size # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM - vector[vector[DTYPE_t]] dist_middle_terms_chunks + # vector[vector[DTYPE_t]] dist_middle_terms_chunks cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, @@ -85,13 +83,14 @@ cdef class MiddleTermComputer{{name_suffix}}: ITYPE_t thread_num ) nogil - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_term, ) nogil @@ -138,23 +137,24 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_ ITYPE_t thread_num ) nogil - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_term, ) nogil cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): cdef: - const DTYPE_t[:] X_data + const {{INPUT_DTYPE_t}}[:] X_data const SPARSE_INDEX_TYPE_t[:] X_indices const SPARSE_INDEX_TYPE_t[:] X_indptr - const DTYPE_t[:] Y_data + const {{INPUT_DTYPE_t}}[:] Y_data const SPARSE_INDEX_TYPE_t[:] Y_indices const SPARSE_INDEX_TYPE_t[:] Y_indptr @@ -176,13 +176,14 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam ITYPE_t thread_num ) nogil - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_term, ) nogil diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index 6b48ed519267b..aa4f416e9badc 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -31,13 +31,16 @@ from ...utils._typedefs cimport DTYPE_t, ITYPE_t, SPARSE_INDEX_TYPE_t # Introduction in Cython: # # https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L50 #noqa -cdef extern from "" namespace "std" nogil: - void fill[Iter, T](Iter first, Iter last, const T& value) except + #noqa +# cdef extern from "" namespace "std" nogil: +# void fill[Iter, T](Iter first, Iter last, const T& value) except + #noqa import numpy as np from scipy.sparse import issparse, csr_matrix from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE + +{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + # TODO: If possible optimize this routine to efficiently treat cases where # `n_samples_X << n_samples_Y` met in practise when X_test consists of a # few samples, and thus when there's a single chunk of X whose number of @@ -46,13 +49,13 @@ from ...utils._typedefs import DTYPE, SPARSE_INDEX_TYPE # TODO: compare this routine with the similar ones in SciPy, especially # `csr_matmat` which might implement a better algorithm. # See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L603-L669 # noqa -cdef void _middle_term_sparse_sparse_64( - const DTYPE_t[:] X_data, +cdef void _middle_term_sparse_sparse_{{name_suffix}}( + const {{INPUT_DTYPE_t}}[:] X_data, const SPARSE_INDEX_TYPE_t[:] X_indices, const SPARSE_INDEX_TYPE_t[:] X_indptr, ITYPE_t X_start, ITYPE_t X_end, - const DTYPE_t[:] Y_data, + const {{INPUT_DTYPE_t}}[:] Y_data, const SPARSE_INDEX_TYPE_t[:] Y_indices, const SPARSE_INDEX_TYPE_t[:] Y_indptr, ITYPE_t Y_start, @@ -79,9 +82,6 @@ cdef void _middle_term_sparse_sparse_64( D[k] += -2 * X_data[X_i_ptr] * Y_data[Y_j_ptr] -{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} - - cdef class MiddleTermComputer{{name_suffix}}: """Helper class to compute a Euclidean distance matrix in chunks. @@ -100,86 +100,17 @@ cdef class MiddleTermComputer{{name_suffix}}: the middle term, i.e. `- 2 X_c_i.Y_c_j^T`. """ - @classmethod - def get_for( - cls, - X, - Y, - effective_n_threads, - chunks_n_threads, - dist_middle_terms_chunks_size, - n_features, - chunk_size, - ) -> MiddleTermComputer{{name_suffix}}: - """Return the DatasetsPair implementation for the given arguments. - - Parameters - ---------- - X : ndarray or CSR sparse matrix of shape (n_samples_X, n_features) - Input data. - If provided as a ndarray, it must be C-contiguous. - - Y : ndarray or CSR sparse matrix of shape (n_samples_Y, n_features) - Input data. - If provided as a ndarray, it must be C-contiguous. - - Returns - ------- - middle_term_computer: MiddleTermComputer{{name_suffix}} - The suited MiddleTermComputer{{name_suffix}} implementation. - """ - X_is_sparse = issparse(X) - Y_is_sparse = issparse(Y) - - if not X_is_sparse and not Y_is_sparse: - return DenseDenseMiddleTermComputer{{name_suffix}}( - X, - Y, - effective_n_threads, - chunks_n_threads, - dist_middle_terms_chunks_size, - n_features, - chunk_size, - ) - if X_is_sparse and Y_is_sparse: - return SparseSparseMiddleTermComputer{{name_suffix}}( - X, - Y, - effective_n_threads, - chunks_n_threads, - dist_middle_terms_chunks_size, - n_features, - chunk_size, - ) - - raise NotImplementedError( - "X and Y must be both CSR sparse matrices or both numpy arrays." - ) - - - @classmethod - def unpack_csr_matrix(cls, X: csr_matrix): - """Ensure that the CSR matrix is indexed with SPARSE_INDEX_TYPE.""" - X_data = np.asarray(X.data, dtype=DTYPE) - X_indices = np.asarray(X.indices, dtype=SPARSE_INDEX_TYPE) - X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) - return X_data, X_indices, X_indptr - def __init__( self, ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, ITYPE_t n_features, ITYPE_t chunk_size, ): self.effective_n_threads = effective_n_threads - self.chunks_n_threads = chunks_n_threads - self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size self.n_features = n_features self.chunk_size = chunk_size - self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) + #self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, @@ -192,7 +123,8 @@ cdef class MiddleTermComputer{{name_suffix}}: return cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: - self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + # self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + return cdef void _parallel_on_X_init_chunk( self, @@ -203,10 +135,7 @@ cdef class MiddleTermComputer{{name_suffix}}: return cdef void _parallel_on_Y_init(self) nogil: - for thread_num in range(self.chunks_n_threads): - self.dist_middle_terms_chunks[thread_num].resize( - self.dist_middle_terms_chunks_size - ) + return cdef void _parallel_on_Y_parallel_init( self, @@ -226,15 +155,16 @@ cdef class MiddleTermComputer{{name_suffix}}: ) nogil: return - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_terms_chunks_size, ) nogil: - return NULL + return cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): @@ -252,15 +182,11 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_ const {{INPUT_DTYPE_t}}[:, ::1] X, const {{INPUT_DTYPE_t}}[:, ::1] Y, ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, ITYPE_t n_features, ITYPE_t chunk_size, ): super().__init__( effective_n_threads, - chunks_n_threads, - dist_middle_terms_chunks_size, n_features, chunk_size, ) @@ -359,16 +285,17 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_ return {{endif}} - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_terms, ) nogil: cdef: - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + #DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() # Careful: LDA, LDB and LDC are given for F-ordered arrays # in BLAS documentations, for instance: @@ -400,8 +327,6 @@ cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_ # dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T` _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) - return dist_middle_terms - cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): """Middle term of the Euclidean distance between two chunked CSR matrices. @@ -417,23 +342,23 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam def __init__( self, - X, - Y, + const {{INPUT_DTYPE_t}}[:] X_data, + const SPARSE_INDEX_TYPE_t[:] X_indices, + const SPARSE_INDEX_TYPE_t[:] X_indptr, + const {{INPUT_DTYPE_t}}[:] Y_data, + const SPARSE_INDEX_TYPE_t[:] Y_indices, + const SPARSE_INDEX_TYPE_t[:] Y_indptr, ITYPE_t effective_n_threads, - ITYPE_t chunks_n_threads, - ITYPE_t dist_middle_terms_chunks_size, ITYPE_t n_features, ITYPE_t chunk_size, ): super().__init__( effective_n_threads, - chunks_n_threads, - dist_middle_terms_chunks_size, n_features, chunk_size, ) - self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) - self.Y_data, self.Y_indices, self.Y_indptr = self.unpack_csr_matrix(Y) + self.X_data, self.X_indices, self.X_indptr = X_data, X_indices, X_indptr + self.Y_data, self.Y_indices, self.Y_indptr = Y_data, Y_indices, Y_indptr cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, @@ -444,11 +369,12 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam ITYPE_t thread_num, ) nogil: # Flush the thread dist_middle_terms_chunks to 0.0 - fill( - self.dist_middle_terms_chunks[thread_num].begin(), - self.dist_middle_terms_chunks[thread_num].end(), - 0.0, - ) + #fill( + # self.dist_middle_terms_chunks[thread_num].begin(), + # self.dist_middle_terms_chunks[thread_num].end(), + # 0.0, + #) + return cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, @@ -459,26 +385,28 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam ITYPE_t thread_num, ) nogil: # Flush the thread dist_middle_terms_chunks to 0.0 - fill( - self.dist_middle_terms_chunks[thread_num].begin(), - self.dist_middle_terms_chunks[thread_num].end(), - 0.0, - ) + #fill( + # self.dist_middle_terms_chunks[thread_num].begin(), + # self.dist_middle_terms_chunks[thread_num].end(), + # 0.0, + #) + return - cdef DTYPE_t * _compute_dist_middle_terms( + cdef void _compute_dist_middle_terms( self, ITYPE_t X_start, ITYPE_t X_end, ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, + DTYPE_t *dist_middle_terms, ) nogil: - cdef: - DTYPE_t *dist_middle_terms = ( - self.dist_middle_terms_chunks[thread_num].data() - ) + #cdef: + # DTYPE_t *dist_middle_terms = ( + # self.dist_middle_terms_chunks[thread_num].data() + # ) - _middle_term_sparse_sparse_64( + _middle_term_sparse_sparse_{{name_suffix}}( self.X_data, self.X_indices, self.X_indptr, @@ -492,7 +420,5 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam dist_middle_terms, ) - return dist_middle_terms - {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd.tp index b6e4508468d2b..1c6a151c454f5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd.tp @@ -29,7 +29,6 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( {{for name_suffix in ['64', '32']}} from ._base cimport BaseDistancesReduction{{name_suffix}} -from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): """float{{name_suffix}} implementation of the RadiusNeighbors.""" @@ -77,14 +76,4 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) ITYPE_t num_threads, ) nogil - -cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}}): - """EuclideanDistance-specialisation of RadiusNeighbors{{name_suffix}}.""" - cdef: - MiddleTermComputer{{name_suffix}} middle_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances - {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index b3f20cac3ea08..4920b6b1fa6d8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -1,6 +1,5 @@ cimport numpy as cnp import numpy as np -import warnings from libcpp.memory cimport shared_ptr, make_shared from libcpp.vector cimport vector @@ -13,8 +12,8 @@ from ...utils._typedefs cimport ITYPE_t, DTYPE_t from ...utils._vector_sentinel cimport vector_to_nd_array from numbers import Real -from scipy.sparse import issparse -from ...utils import check_array, check_scalar, _in_unstable_openblas_configuration +# TODO: reintroduce _in_unstable... warnings +from ...utils import check_scalar from ...utils.fixes import threadpool_limits cnp.import_array() @@ -43,15 +42,7 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( ##################### {{for name_suffix in ['64', '32']}} -from ._base cimport ( - BaseDistancesReduction{{name_suffix}}, - _sqeuclidean_row_norms{{name_suffix}} -) - -from ._datasets_pair cimport DatasetsPair{{name_suffix}} - -from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}} - +from ._base cimport BaseDistancesReduction{{name_suffix}} cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): """float{{name_suffix}} implementation of the RadiusNeighbors.""" @@ -82,38 +73,15 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) No instance should directly be created outside of this class method. """ - if ( - metric in ("euclidean", "sqeuclidean") - and not (issparse(X) ^ issparse(Y)) # "^" is XOR - ): - # Specialized implementation of RadiusNeighbors for the Euclidean - # distance for the dense-dense and sparse-sparse cases. - # This implementation computes the distances by chunk using - # a decomposition of the Squared Euclidean distance. - # This specialisation has an improved arithmetic intensity for both - # the dense and sparse settings, allowing in most case speed-ups of - # several orders of magnitude compared to the generic RadiusNeighbors - # implementation. - # For more information see MiddleTermComputer. - use_squared_distances = metric == "sqeuclidean" - pda = EuclideanRadiusNeighbors{{name_suffix}}( - X=X, Y=Y, radius=radius, - use_squared_distances=use_squared_distances, - chunk_size=chunk_size, - strategy=strategy, - sort_results=sort_results, - metric_kwargs=metric_kwargs, - ) - else: - # Fall back on a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = RadiusNeighbors{{name_suffix}}( - datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), - radius=radius, - chunk_size=chunk_size, - strategy=strategy, - sort_results=sort_results, - ) + pda = RadiusNeighbors{{name_suffix}}( + X, Y, + radius=radius, + chunk_size=chunk_size, + strategy=strategy, + sort_results=sort_results, + metric=metric, + metric_kwargs=metric_kwargs, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -128,16 +96,20 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) def __init__( self, - DatasetsPair{{name_suffix}} datasets_pair, + X, Y, DTYPE_t radius, chunk_size=None, strategy=None, sort_results=False, + metric="euclidean", + metric_kwargs=None, ): super().__init__( - datasets_pair=datasets_pair, + X, Y, chunk_size=chunk_size, strategy=strategy, + metric=metric, + metric_kwargs=metric_kwargs, ) self.radius = check_scalar(radius, "radius", Real, min_val=0) @@ -176,10 +148,11 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) cdef: ITYPE_t i, j DTYPE_t r_dist_i_j + ITYPE_t n_Y = Y_end - Y_start for i in range(X_start, X_end): for j in range(Y_start, Y_end): - r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) + r_dist_i_j = self.datasets_pair.surrogate_dist(X_start, Y_start, i, j, n_Y, thread_num) if r_dist_i_j <= self.r_radius: deref(self.neigh_distances_chunks[thread_num])[i].push_back(r_dist_i_j) deref(self.neigh_indices_chunks[thread_num])[i].push_back(j) @@ -196,6 +169,14 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) return coerce_vectors_to_nd_arrays(self.neigh_indices) + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num + ) nogil: + self.datasets_pair._parallel_on_X_parallel_init(thread_num) + + @final cdef void _parallel_on_X_init_chunk( self, ITYPE_t thread_num, @@ -207,6 +188,20 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) # thread vectors' pointers to the main vectors'. self.neigh_distances_chunks[thread_num] = self.neigh_distances self.neigh_indices_chunks[thread_num] = self.neigh_indices + self.datasets_pair._parallel_on_X_init_chunk(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + self.datasets_pair._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num, + ) @final cdef void _parallel_on_X_prange_iter_finalize( @@ -227,6 +222,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) deref(self.neigh_indices)[idx].size() ) + @final cdef void _parallel_on_Y_init( self, ) nogil: @@ -238,6 +234,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) for thread_num in range(self.chunks_n_threads): self.neigh_distances_chunks[thread_num] = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) self.neigh_indices_chunks[thread_num] = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) + self.datasets_pair._parallel_on_Y_init() @final cdef void _merge_vectors( @@ -272,6 +269,33 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) ) last_element_idx += deref(self.neigh_distances_chunks[thread_num])[idx].size() + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + self.datasets_pair._parallel_on_Y_parallel_init(thread_num, X_start, X_end) + + @final + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + self.datasets_pair._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + X_start, + X_end, + Y_start, + Y_end, + thread_num, + ) + + @final cdef void _parallel_on_Y_finalize( self, ) nogil: @@ -299,222 +323,22 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) ) return - cdef void compute_exact_distances(self) nogil: """Convert rank-preserving distances to pairwise distances in parallel.""" cdef: ITYPE_t i, j - - for i in prange(self.n_samples_X, nogil=True, schedule='static', - num_threads=self.effective_n_threads): - for j in range(deref(self.neigh_indices)[i].size()): - deref(self.neigh_distances)[i][j] = ( - self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against potential -0., causing nan production. - max(deref(self.neigh_distances)[i][j], 0.) - ) - ) - - -cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}}): - """EuclideanDistance-specialisation of RadiusNeighbors{{name_suffix}}.""" - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - return (RadiusNeighbors{{name_suffix}}.is_usable_for(X, Y, metric) - and not _in_unstable_openblas_configuration()) - - def __init__( - self, - X, - Y, - DTYPE_t radius, - bint use_squared_distances=False, - chunk_size=None, - strategy=None, - sort_results=False, - metric_kwargs=None, - ): - if ( - isinstance(metric_kwargs, dict) and - (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) - ): - warnings.warn( - f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " - f"usable for this case (EuclideanRadiusNeighbors64) and will be ignored.", - UserWarning, - stacklevel=3, - ) - - super().__init__( - # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), - radius=radius, - chunk_size=chunk_size, - strategy=strategy, - sort_results=sort_results, - ) - cdef: - ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - - self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for( - X, - Y, - self.effective_n_threads, - self.chunks_n_threads, - dist_middle_terms_chunks_size, - n_features=X.shape[1], - chunk_size=self.chunk_size, - ) - - if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = check_array( - metric_kwargs.pop("Y_norm_squared"), - ensure_2d=False, - input_name="Y_norm_squared", - dtype=np.float64, - ) - else: - self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( - Y, - self.effective_n_threads, - ) - - if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: - self.X_norm_squared = check_array( - metric_kwargs.pop("X_norm_squared"), - ensure_2d=False, - input_name="X_norm_squared", - dtype=np.float64, - ) - else: - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms{{name_suffix}}( - X, - self.effective_n_threads, - ) - ) - - self.use_squared_distances = use_squared_distances - - if use_squared_distances: - # In this specialisation and this setup, the value passed to the radius is - # already considered to be the adapted radius, so we overwrite it. - self.r_radius = radius - - @final - cdef void _parallel_on_X_parallel_init( - self, - ITYPE_t thread_num, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) - self.middle_term_computer._parallel_on_X_parallel_init(thread_num) - - @final - cdef void _parallel_on_X_init_chunk( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) - self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num, - ) - - @final - cdef void _parallel_on_Y_init( - self, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_Y_init(self) - self.middle_term_computer._parallel_on_Y_init() - - @final - cdef void _parallel_on_Y_parallel_init( - self, - ITYPE_t thread_num, - ITYPE_t X_start, - ITYPE_t X_end, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) - self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) - - @final - cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - RadiusNeighbors{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - self, - X_start, X_end, - Y_start, Y_end, - thread_num, - ) - self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( - X_start, X_end, Y_start, Y_end, thread_num - ) - - @final - cdef void compute_exact_distances(self) nogil: - if not self.use_squared_distances: - RadiusNeighbors{{name_suffix}}.compute_exact_distances(self) - - @final - cdef void _compute_and_reduce_distances_on_chunks( - self, - ITYPE_t X_start, - ITYPE_t X_end, - ITYPE_t Y_start, - ITYPE_t Y_end, - ITYPE_t thread_num, - ) nogil: - cdef: - ITYPE_t i, j - DTYPE_t sqeuclidean_dist_i_j - ITYPE_t n_X = X_end - X_start - ITYPE_t n_Y = Y_end - Y_start - DTYPE_t *dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms( - X_start, X_end, Y_start, Y_end, thread_num + bint need_to_compute_exact_dist = ( + self.datasets_pair.need_to_compute_exact_dist() ) - - # Pushing the distance and their associated indices in vectors. - for i in range(n_X): - for j in range(n_Y): - sqeuclidean_dist_i_j = ( - self.X_norm_squared[i + X_start] - + dist_middle_terms[i * n_Y + j] - + self.Y_norm_squared[j + Y_start] - ) - - # Catastrophic cancellation might cause -0. to be present, - # e.g. when computing d(x_i, y_i) when X is Y. - sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j) - - if sqeuclidean_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(sqeuclidean_dist_i_j) - deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) + if need_to_compute_exact_dist: + for i in prange(self.n_samples_X, nogil=True, schedule='static', + num_threads=self.effective_n_threads): + for j in range(deref(self.neigh_indices)[i].size()): + deref(self.neigh_distances)[i][j] = ( + self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against potential -0., causing nan production. + max(deref(self.neigh_distances)[i][j], 0.) + ) + ) {{endfor}} diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 4fe8013cd3602..06113581adb9e 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -637,7 +637,7 @@ def test_argkmin_factory_method_wrong_usages(): "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), } - message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + message = r"Some metric_kwargs have been passed \({'p': 3" with pytest.warns(UserWarning, match=message): ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) @@ -723,7 +723,7 @@ def test_radius_neighbors_factory_method_wrong_usages(): "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), } - message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + message = r"Some metric_kwargs have been passed \({'p': 3" with pytest.warns(UserWarning, match=message): RadiusNeighbors.compute(