From 49cf7d112af90a047ee9582ffb7a8b60582f1b01 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 17 Oct 2022 16:45:44 +0200 Subject: [PATCH] Do not slice memoryviews in _compute_dist_middle_terms See the reasons here: https://github.com/scikit-learn/scikit-learn/issues/17299 --- .../_gemm_term_computer.pyx.tp | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp index e69d1c3df9f7d..e040417cbe705 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp @@ -175,8 +175,6 @@ cdef class GEMMTermComputer{{name_suffix}}: ITYPE_t thread_num, ) nogil: cdef: - const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :] - const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :] DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() # Careful: LDA, LDB and LDC are given for F-ordered arrays @@ -187,9 +185,9 @@ cdef class GEMMTermComputer{{name_suffix}}: BLAS_Order order = RowMajor BLAS_Trans ta = NoTrans BLAS_Trans tb = Trans - ITYPE_t m = X_c.shape[0] - ITYPE_t n = Y_c.shape[0] - ITYPE_t K = X_c.shape[1] + ITYPE_t m = X_end - X_start + ITYPE_t n = Y_end - Y_start + ITYPE_t K = self.n_features DTYPE_t alpha = - 2. {{if upcast_to_float64}} DTYPE_t * A = self.X_c_upcast[thread_num].data() @@ -198,15 +196,15 @@ cdef class GEMMTermComputer{{name_suffix}}: # Casting for A and B to remove the const is needed because APIs exposed via # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * A = &X_c[0, 0] - DTYPE_t * B = &Y_c[0, 0] + DTYPE_t * A = &self.X[X_start, 0] + DTYPE_t * B = &self.Y[Y_start, 0] {{endif}} - ITYPE_t lda = X_c.shape[1] - ITYPE_t ldb = X_c.shape[1] + ITYPE_t lda = self.n_features + ITYPE_t ldb = self.n_features DTYPE_t beta = 0. - ITYPE_t ldc = Y_c.shape[0] + ITYPE_t ldc = Y_end - Y_start - # dist_middle_terms = `-2 * X_c @ Y_c.T` + # dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T` _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) return dist_middle_terms