Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 100949f

Browse files
committed
DOC Better document implementation
Signed-off-by: Julien Jerphanion <[email protected]>
1 parent 22d2cd2 commit 100949f

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

doc/whats_new/v1.3.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ Changes impacting all modules
8080
Euclidean distances (sometimes followed by a fused reduction operation) for a
8181
pair of datasets consisting of a sparse CSR matrix and a dense NumPy.
8282

83-
This can improve the performance following functions and estimators.
84-
83+
This can improve the performance of following functions and estimators:
84+
8585
- :func:`sklearn.metrics.pairwise_distances_argmin`
8686
- :func:`sklearn.metrics.pairwise_distances_argmin_min`
8787
- :class:`sklearn.cluster.AffinityPropagation`
@@ -105,7 +105,7 @@ Changes impacting all modules
105105

106106
A typical example of this performance improvement happens when passing a sparse
107107
CSR matrix to the `predict` or `transform` method of estimators that rely on
108-
a dense numpy representation to store their fitted parameters (or the reverse).
108+
a dense NumPy representation to store their fitted parameters (or the reverse).
109109

110110
For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 times faster
111111
for this case now on commonly available laptops.

sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
194194

195195
const {{INPUT_DTYPE_t}}[:, ::1] Y
196196

197+
# We treat the dense-sparse case with the sparse-dense case by simply
198+
# treating the dist_middle_terms as F-ordered and by swapping arguments.
199+
# This attribute is meant to encode the case and adapt the logic
200+
# accordingly.
197201
bint c_ordered_middle_term
198202

199203
cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(

sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ cdef void _middle_term_sparse_sparse_64(
4949
const SPARSE_INDEX_TYPE_t[:] Y_indptr,
5050
ITYPE_t Y_start,
5151
ITYPE_t Y_end,
52-
DTYPE_t * D,
52+
DTYPE_t * dist_middle_terms,
5353
) noexcept nogil:
54-
# This routine assumes that D points to the first element of a
54+
# This routine assumes that D is a pointer to the first element of a
5555
# zeroed buffer of length at least equal to n_X × n_Y, conceptually
5656
# representing a 2-d C-ordered array.
5757
cdef:
@@ -68,7 +68,7 @@ cdef void _middle_term_sparse_sparse_64(
6868
for y_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
6969
y_col = Y_indices[y_ptr]
7070
if x_col == y_col:
71-
D[k] += -2 * X_data[x_ptr] * Y_data[y_ptr]
71+
dist_middle_terms[k] += -2 * X_data[x_ptr] * Y_data[y_ptr]
7272

7373

7474
{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
@@ -83,11 +83,11 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}(
8383
ITYPE_t Y_start,
8484
ITYPE_t Y_end,
8585
bint c_ordered_middle_term,
86-
DTYPE_t * D,
86+
DTYPE_t * dist_middle_terms,
8787
) nogil:
88-
# This routine assumes that D points to the first element of a
89-
# zeroed buffer of length at least equal to n_X × n_Y, conceptually
90-
# representing a 2-d C-ordered array.
88+
# This routine assumes that dist_middle_terms is a pointer to the first element
89+
# of a zeroed buffer of length at least equal to n_X × n_Y, conceptually
90+
# representing a 2-d C-ordered of F-ordered array.
9191
cdef:
9292
ITYPE_t i, j, k
9393
ITYPE_t n_X = X_end - X_start
@@ -99,7 +99,7 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}(
9999
k = i * n_Y + j if c_ordered_middle_term else j * n_X + i
100100
for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
101101
X_i_col_idx = X_indices[X_i_ptr]
102-
D[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx]
102+
dist_middle_terms[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx]
103103

104104

105105
cdef class MiddleTermComputer{{name_suffix}}:
@@ -183,9 +183,16 @@ cdef class MiddleTermComputer{{name_suffix}}:
183183
c_ordered_middle_term=True
184184
)
185185
if not X_is_sparse and Y_is_sparse:
186+
# NOTE: The Dense-Sparse case is implement via the Sparse-Dense case.
187+
#
188+
# To do so:
189+
# - X (dense) and Y (sparse) are swapped
190+
# - the distance middle term is seen as F-ordered for consistency
191+
# (c_ordered_middle_term = False)
186192
return SparseDenseMiddleTermComputer{{name_suffix}}(
187-
Y,
188-
X,
193+
# Mind that X and Y are swapped here.
194+
X=Y,
195+
Y=X,
189196
effective_n_threads,
190197
chunks_n_threads,
191198
dist_middle_terms_chunks_size,
@@ -572,7 +579,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
572579
ITYPE_t Y_end,
573580
ITYPE_t thread_num,
574581
) noexcept nogil:
575-
# Flush the thread dist_middle_terms_chunks to 0.0
582+
# Fill the thread's dist_middle_terms_chunks with 0.0 before
583+
# computing its elements in _compute_dist_middle_terms.
576584
fill(
577585
self.dist_middle_terms_chunks[thread_num].begin(),
578586
self.dist_middle_terms_chunks[thread_num].end(),
@@ -587,7 +595,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
587595
ITYPE_t Y_end,
588596
ITYPE_t thread_num,
589597
) noexcept nogil:
590-
# Flush the thread dist_middle_terms_chunks to 0.0
598+
# Fill the thread's dist_middle_terms_chunks with 0.0 before
599+
# computing its elements in _compute_dist_middle_terms.
591600
fill(
592601
self.dist_middle_terms_chunks[thread_num].begin(),
593602
self.dist_middle_terms_chunks[thread_num].end(),
@@ -607,15 +616,22 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
607616
self.dist_middle_terms_chunks[thread_num].data()
608617
)
609618

619+
# For the dense-sparse case, we use the sparse-dense case
620+
# with dist_middle_terms seen as F-ordered.
621+
# Hence we swap indices pointers here.
622+
if not self.c_ordered_middle_term:
623+
X_start, Y_start = Y_start, X_start
624+
X_end, Y_end = Y_end, X_end
625+
610626
_middle_term_sparse_dense_{{name_suffix}}(
611627
self.X_data,
612628
self.X_indices,
613629
self.X_indptr,
614-
X_start if self.c_ordered_middle_term else Y_start,
615-
X_end if self.c_ordered_middle_term else Y_end,
630+
X_start,
631+
X_end,
616632
self.Y,
617-
Y_start if self.c_ordered_middle_term else X_start,
618-
Y_end if self.c_ordered_middle_term else X_end,
633+
Y_start,
634+
Y_end,
619635
self.c_ordered_middle_term,
620636
dist_middle_terms,
621637
)

0 commit comments

Comments
 (0)