diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 29ac839187fc9..5fbe785909096 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -968,8 +968,102 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return np.asarray(self.argkmin_indices) +cdef class GEMMTermComputer: + """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. + + `FastEuclidean*` classes internally compute the squared Euclidean distances between + chunks of vectors X_c and Y_c using using the decomposition: + + + ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + + + This helper class is in charge of wrapping the common logic to compute + the middle term `- 2 X_c_i.Y_c_j^T` with a call to GEMM, which has a high + arithmetic intensity. + """ + + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + ITYPE_t dist_middle_terms_chunks_size + + # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM + vector[vector[DTYPE_t]] dist_middle_terms_chunks + + def __init__(self, + DTYPE_t[:, ::1] X, + DTYPE_t[:, ::1] Y, + ITYPE_t effective_n_threads, + ITYPE_t chunks_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ): + self.X = X + self.Y = Y + self.effective_n_threads = effective_n_threads + self.chunks_n_threads = chunks_n_threads + self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size + + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) + + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: + self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) + + cdef void _parallel_on_Y_init(self) nogil: + for thread_num in range(self.chunks_n_threads): + self.dist_middle_terms_chunks[thread_num].resize( + self.dist_middle_terms_chunks_size + ) + + cdef DTYPE_t * _compute_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * A = &X_c[0, 0] + ITYPE_t lda = X_c.shape[1] + DTYPE_t * B = &Y_c[0, 0] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + + return dist_middle_terms + + cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): - """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance. + """Fast specialized variant for PairwiseDistancesArgKmin on EuclideanDistance. The full pairwise squared distances matrix is computed as follows: @@ -980,19 +1074,16 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): Notes ----- This implementation has a superior arithmetic intensity and hence - better running time when the alternative is IO bound, but it can suffer + better running time when the variant is IO bound, but it can suffer from numerical instability caused by catastrophic cancellation potentially introduced by the subtraction in the arithmetic expression above. """ cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y + GEMMTermComputer gemm_term_computer const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared - # Buffers for GEMM - DTYPE_t ** dist_middle_terms_chunks bint use_squared_distances @classmethod @@ -1028,29 +1119,28 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: DenseDenseDatasetsPair datasets_pair = self.datasets_pair - self.X, self.Y = datasets_pair.X, datasets_pair.Y + ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + + self.gemm_term_computer = GEMMTermComputer( + datasets_pair.X, + datasets_pair.Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: - self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms(self.X, self.effective_n_threads) + _sqeuclidean_row_norms(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances - # Temporary datastructures used in threads - self.dist_middle_terms_chunks = malloc( - sizeof(DTYPE_t *) * self.chunks_n_threads - ) - - def __dealloc__(self): - if self.dist_middle_terms_chunks is not NULL: - free(self.dist_middle_terms_chunks) - @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: @@ -1062,19 +1152,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ITYPE_t thread_num, ) nogil: PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num) - - # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num] = malloc( - self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) - ) - - @final - cdef void _parallel_on_X_parallel_finalize( - self, - ITYPE_t thread_num - ) nogil: - PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num) - free(self.dist_middle_terms_chunks[thread_num]) + self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) @final cdef void _parallel_on_Y_init( @@ -1082,22 +1160,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) nogil: cdef ITYPE_t thread_num PairwiseDistancesArgKmin._parallel_on_Y_init(self) - - for thread_num in range(self.chunks_n_threads): - # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num] = malloc( - self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) - ) - - @final - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesArgKmin._parallel_on_Y_finalize(self) - - for thread_num in range(self.chunks_n_threads): - free(self.dist_middle_terms_chunks[thread_num]) + self.gemm_term_computer._parallel_on_Y_init() @final cdef void _compute_and_reduce_distances_on_chunks( @@ -1110,42 +1173,20 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ) nogil: cdef: ITYPE_t i, j + DTYPE_t squared_dist_i_j + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t * dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) + DTYPE_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t * heaps_indices = self.heaps_indices_chunks[thread_num] - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] - DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] - ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] - - # Careful: LDA, LDB and LDC are given for F-ordered arrays - # in BLAS documentations, for instance: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa - # - # Here, we use their counterpart values to work with C-ordered arrays. - BLAS_Order order = RowMajor - BLAS_Trans ta = NoTrans - BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] - DTYPE_t alpha = - 2. - # Casting for A and B to remove the const is needed because APIs exposed via - # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = & X_c[0, 0] - ITYPE_t lda = X_c.shape[1] - DTYPE_t * B = & Y_c[0, 0] - ITYPE_t ldb = X_c.shape[1] - DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] - - # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) # Pushing the distance and their associated indices on heaps # which keep tracks of the argkmin. - for i in range(X_c.shape[0]): - for j in range(Y_c.shape[0]): + for i in range(n_X): + for j in range(n_Y): heap_push( heaps_r_distances + i * self.k, heaps_indices + i * self.k, @@ -1156,7 +1197,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): # ( self.X_norm_squared[i + X_start] + - dist_middle_terms[i * Y_c.shape[0] + j] + + dist_middle_terms[i * n_Y + j] + self.Y_norm_squared[j + Y_start] ), j + Y_start, @@ -1566,7 +1607,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRadiusNeighborhood): - """Fast specialized alternative for PairwiseDistancesRadiusNeighborhood on EuclideanDistance. + """Fast specialized variant for PairwiseDistancesRadiusNeighborhood on EuclideanDistance. The full pairwise squared distances matrix is computed as follows: @@ -1577,20 +1618,17 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad Notes ----- This implementation has a superior arithmetic intensity and hence - better running time when the alternative is IO bound, but it can suffer + better running time when the variant is IO bound, but it can suffer from numerical instability caused by catastrophic cancellation potentially introduced by the subtraction in the arithmetic expression above. numerical precision is needed. """ cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y + GEMMTermComputer gemm_term_computer const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared - # Buffers for GEMM - vector[vector[DTYPE_t]] dist_middle_terms_chunks bint use_squared_distances @classmethod @@ -1621,17 +1659,25 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair cdef: DenseDenseDatasetsPair datasets_pair = self.datasets_pair - self.X, self.Y = datasets_pair.X, datasets_pair.Y + ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk + + self.gemm_term_computer = GEMMTermComputer( + datasets_pair.X, + datasets_pair.Y, + self.effective_n_threads, + self.chunks_n_threads, + dist_middle_terms_chunks_size, + ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") else: - self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms(self.X, self.effective_n_threads) + _sqeuclidean_row_norms(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @@ -1640,11 +1686,6 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad # already considered to be the adapted radius, so we overwrite it. self.r_radius = radius - # Temporary datastructures used in threads - self.dist_middle_terms_chunks = vector[vector[DTYPE_t]]( - self.effective_n_threads - ) - @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: @@ -1656,11 +1697,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad ITYPE_t thread_num, ) nogil: PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_init(self, thread_num) - - # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num].resize( - self.Y_n_samples_chunk * self.X_n_samples_chunk - ) + self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) @final cdef void _parallel_on_Y_init( @@ -1668,12 +1705,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad ) nogil: cdef ITYPE_t thread_num PairwiseDistancesRadiusNeighborhood._parallel_on_Y_init(self) - - for thread_num in range(self.chunks_n_threads): - # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num].resize( - self.Y_n_samples_chunk * self.X_n_samples_chunk - ) + self.gemm_term_computer._parallel_on_Y_init() @final cdef void _compute_and_reduce_distances_on_chunks( @@ -1687,46 +1719,22 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad cdef: ITYPE_t i, j DTYPE_t squared_dist_i_j - - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() - - # Careful: LDA, LDB and LDC are given for F-ordered arrays - # in BLAS documentations, for instance: - # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa - # - # Here, we use their counterpart values to work with C-ordered arrays. - BLAS_Order order = RowMajor - BLAS_Trans ta = NoTrans - BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] - DTYPE_t alpha = - 2. - # Casting for A and B to remove the const is needed because APIs exposed via - # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = &X_c[0, 0] - ITYPE_t lda = X_c.shape[1] - DTYPE_t * B = &Y_c[0, 0] - ITYPE_t ldb = X_c.shape[1] - DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] - - # dist_middle_terms = `-2 * X_c @ Y_c.T` - _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + DTYPE_t *dist_middle_terms = self.gemm_term_computer._compute_distances_on_chunks( + X_start, X_end, Y_start, Y_end, thread_num + ) # Pushing the distance and their associated indices in vectors. - for i in range(X_c.shape[0]): - for j in range(Y_c.shape[0]): + for i in range(n_X): + for j in range(n_Y): # Using the squared euclidean distance as the rank-preserving distance: # # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² # squared_dist_i_j = ( self.X_norm_squared[i + X_start] - + dist_middle_terms[i * Y_c.shape[0] + j] + + dist_middle_terms[i * n_Y + j] + self.Y_norm_squared[j + Y_start] ) if squared_dist_i_j <= self.r_radius: