Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 379ffae

Browse files
committed
Revert "Do not slice memoryviews in _compute_dist_middle_terms"
This reverts commit f2e917b.
1 parent f2e917b commit 379ffae

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ cdef class GEMMTermComputer{{name_suffix}}:
175175
ITYPE_t thread_num,
176176
) nogil:
177177
cdef:
178+
const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
179+
const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
178180
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
179181

180182
# Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -185,9 +187,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
185187
BLAS_Order order = RowMajor
186188
BLAS_Trans ta = NoTrans
187189
BLAS_Trans tb = Trans
188-
ITYPE_t m = X_end - X_start
189-
ITYPE_t n = Y_end - Y_start
190-
ITYPE_t K = self.n_features
190+
ITYPE_t m = X_c.shape[0]
191+
ITYPE_t n = Y_c.shape[0]
192+
ITYPE_t K = X_c.shape[1]
191193
DTYPE_t alpha = - 2.
192194
{{if upcast_to_float64}}
193195
DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -196,15 +198,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
196198
# Casting for A and B to remove the const is needed because APIs exposed via
197199
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
198200
# See: https://github.com/scipy/scipy/issues/14262
199-
DTYPE_t * A = <DTYPE_t *> &self.X[X_start, 0]
200-
DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start, 0]
201+
DTYPE_t * A = <DTYPE_t *> &X_c[0, 0]
202+
DTYPE_t * B = <DTYPE_t *> &Y_c[0, 0]
201203
{{endif}}
202-
ITYPE_t lda = self.n_features
203-
ITYPE_t ldb = self.n_features
204+
ITYPE_t lda = X_c.shape[1]
205+
ITYPE_t ldb = X_c.shape[1]
204206
DTYPE_t beta = 0.
205-
ITYPE_t ldc = Y_end - Y_start
207+
ITYPE_t ldc = Y_c.shape[0]
206208

207-
# dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T`
209+
# dist_middle_terms = `-2 * X_c @ Y_c.T`
208210
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
209211

210212
return dist_middle_terms

0 commit comments

Comments
 (0)