@@ -49,9 +49,9 @@ cdef void _middle_term_sparse_sparse_64(
49
49
const SPARSE_INDEX_TYPE_t[:] Y_indptr,
50
50
ITYPE_t Y_start,
51
51
ITYPE_t Y_end,
52
- DTYPE_t * D ,
52
+ DTYPE_t * dist_middle_terms ,
53
53
) noexcept nogil:
54
- # This routine assumes that D points to the first element of a
54
+ # This routine assumes that D is a pointer to the first element of a
55
55
# zeroed buffer of length at least equal to n_X × n_Y, conceptually
56
56
# representing a 2-d C-ordered array.
57
57
cdef:
@@ -68,7 +68,7 @@ cdef void _middle_term_sparse_sparse_64(
68
68
for y_ptr in range(Y_indptr[Y_start+j], Y_indptr[Y_start+j+1]):
69
69
y_col = Y_indices[y_ptr]
70
70
if x_col == y_col:
71
- D [k] += -2 * X_data[x_ptr] * Y_data[y_ptr]
71
+ dist_middle_terms [k] += -2 * X_data[x_ptr] * Y_data[y_ptr]
72
72
73
73
74
74
{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
@@ -83,11 +83,11 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}(
83
83
ITYPE_t Y_start,
84
84
ITYPE_t Y_end,
85
85
bint c_ordered_middle_term,
86
- DTYPE_t * D ,
86
+ DTYPE_t * dist_middle_terms ,
87
87
) nogil:
88
- # This routine assumes that D points to the first element of a
89
- # zeroed buffer of length at least equal to n_X × n_Y, conceptually
90
- # representing a 2-d C-ordered array.
88
+ # This routine assumes that dist_middle_terms is a pointer to the first element
89
+ # of a zeroed buffer of length at least equal to n_X × n_Y, conceptually
90
+ # representing a 2-d C-ordered of F-ordered array.
91
91
cdef:
92
92
ITYPE_t i, j, k
93
93
ITYPE_t n_X = X_end - X_start
@@ -99,7 +99,7 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}(
99
99
k = i * n_Y + j if c_ordered_middle_term else j * n_X + i
100
100
for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]):
101
101
X_i_col_idx = X_indices[X_i_ptr]
102
- D [k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx]
102
+ dist_middle_terms [k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx]
103
103
104
104
105
105
cdef class MiddleTermComputer{{name_suffix}}:
@@ -183,9 +183,16 @@ cdef class MiddleTermComputer{{name_suffix}}:
183
183
c_ordered_middle_term=True
184
184
)
185
185
if not X_is_sparse and Y_is_sparse:
186
+ # NOTE: The Dense-Sparse case is implement via the Sparse-Dense case.
187
+ #
188
+ # To do so:
189
+ # - X (dense) and Y (sparse) are swapped
190
+ # - the distance middle term is seen as F-ordered for consistency
191
+ # (c_ordered_middle_term = False)
186
192
return SparseDenseMiddleTermComputer{{name_suffix}}(
187
- Y,
188
- X,
193
+ # Mind that X and Y are swapped here.
194
+ X=Y,
195
+ Y=X,
189
196
effective_n_threads,
190
197
chunks_n_threads,
191
198
dist_middle_terms_chunks_size,
@@ -572,7 +579,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
572
579
ITYPE_t Y_end,
573
580
ITYPE_t thread_num,
574
581
) noexcept nogil:
575
- # Flush the thread dist_middle_terms_chunks to 0.0
582
+ # Fill the thread's dist_middle_terms_chunks with 0.0 before
583
+ # computing its elements in _compute_dist_middle_terms.
576
584
fill(
577
585
self.dist_middle_terms_chunks[thread_num].begin(),
578
586
self.dist_middle_terms_chunks[thread_num].end(),
@@ -587,7 +595,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
587
595
ITYPE_t Y_end,
588
596
ITYPE_t thread_num,
589
597
) noexcept nogil:
590
- # Flush the thread dist_middle_terms_chunks to 0.0
598
+ # Fill the thread's dist_middle_terms_chunks with 0.0 before
599
+ # computing its elements in _compute_dist_middle_terms.
591
600
fill(
592
601
self.dist_middle_terms_chunks[thread_num].begin(),
593
602
self.dist_middle_terms_chunks[thread_num].end(),
@@ -607,15 +616,22 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name
607
616
self.dist_middle_terms_chunks[thread_num].data()
608
617
)
609
618
619
+ # For the dense-sparse case, we use the sparse-dense case
620
+ # with dist_middle_terms seen as F-ordered.
621
+ # Hence we swap indices pointers here.
622
+ if not self.c_ordered_middle_term:
623
+ X_start, Y_start = Y_start, X_start
624
+ X_end, Y_end = Y_end, X_end
625
+
610
626
_middle_term_sparse_dense_{{name_suffix}}(
611
627
self.X_data,
612
628
self.X_indices,
613
629
self.X_indptr,
614
- X_start if self.c_ordered_middle_term else Y_start ,
615
- X_end if self.c_ordered_middle_term else Y_end ,
630
+ X_start,
631
+ X_end,
616
632
self.Y,
617
- Y_start if self.c_ordered_middle_term else X_start ,
618
- Y_end if self.c_ordered_middle_term else X_end ,
633
+ Y_start,
634
+ Y_end,
619
635
self.c_ordered_middle_term,
620
636
dist_middle_terms,
621
637
)
0 commit comments