@@ -13,7 +13,7 @@ cimport numpy as np
13
13
from libc.math cimport floor, sqrt
14
14
from libc.stdlib cimport free, malloc
15
15
16
- from cython cimport floating, integral
16
+ from cython cimport floating
17
17
from cython.parallel cimport parallel, prange
18
18
19
19
DEF CHUNK_SIZE = 256 # number of vectors
@@ -22,7 +22,6 @@ DEF MIN_CHUNK_SAMPLES = 20
22
22
23
23
DEF FLOAT_INF = 1e36
24
24
25
- from ..neighbors._neighbors_heap cimport _simultaneous_sort, _push
26
25
from ..utils._cython_blas cimport (
27
26
BLAS_Order,
28
27
BLAS_Trans,
@@ -32,7 +31,11 @@ from ..utils._cython_blas cimport (
32
31
Trans,
33
32
_gemm,
34
33
)
34
+
35
+ from ..utils._heap cimport _simultaneous_sort, _push
35
36
from ..utils._openmp_helpers import _openmp_effective_n_threads
37
+ from ..utils._typedefs cimport ITYPE_t
38
+ from ..utils._typedefs import ITYPE
36
39
37
40
38
41
# ## argkmin helpers
@@ -43,10 +46,10 @@ cdef void _argkmin_on_chunk(
43
46
floating[::1 ] Y_sq_norms, # IN
44
47
floating * dist_middle_terms, # IN
45
48
floating * heaps_red_distances, # IN/OUT
46
- integral * heaps_indices, # IN/OUT
47
- integral k, # IN
49
+ ITYPE_t * heaps_indices, # IN/OUT
50
+ ITYPE_t k, # IN
48
51
# ID of the first element of Y_c
49
- integral Y_idx_offset,
52
+ ITYPE_t Y_idx_offset,
50
53
) nogil:
51
54
"""
52
55
Critical part of the computation of pairwise distances.
@@ -55,7 +58,7 @@ cdef void _argkmin_on_chunk(
55
58
on the gemm-trick.
56
59
"""
57
60
cdef:
58
- integral i, j
61
+ ITYPE_t i, j
59
62
# Instead of computing the full pairwise squared distances matrix,
60
63
# ||X_c - Y_c||² = ||X_c||² - 2 X_c.Y_c^T + ||Y_c||²,
61
64
# we only need to store the - 2 X_c.Y_c^T + ||Y_c||²
@@ -91,44 +94,44 @@ cdef int _argkmin_on_X(
91
94
floating[:, ::1 ] X, # IN
92
95
floating[:, ::1 ] Y, # IN
93
96
floating[::1 ] Y_sq_norms, # IN
94
- integral chunk_size, # IN
95
- integral effective_n_threads, # IN
96
- integral [:, ::1 ] argkmin_indices, # OUT
97
+ ITYPE_t chunk_size, # IN
98
+ ITYPE_t effective_n_threads, # IN
99
+ ITYPE_t [:, ::1 ] argkmin_indices, # OUT
97
100
floating[:, ::1 ] argkmin_red_distances, # OUT
98
101
) nogil:
99
102
""" Computes the argkmin of each vector (row) of X on Y
100
103
by parallelising computation on chunks of X.
101
104
"""
102
105
cdef:
103
- integral k = argkmin_indices.shape[1 ]
104
- integral d = X.shape[1 ]
105
- integral sf = sizeof(floating)
106
- integral si = sizeof(integral )
107
- integral n_samples_chunk = max (MIN_CHUNK_SAMPLES, chunk_size)
108
-
109
- integral n_train = Y.shape[0 ]
110
- integral Y_n_samples_chunk = min (n_train, n_samples_chunk)
111
- integral Y_n_full_chunks = n_train / Y_n_samples_chunk
112
- integral Y_n_samples_rem = n_train % Y_n_samples_chunk
113
-
114
- integral n_test = X.shape[0 ]
115
- integral X_n_samples_chunk = min (n_test, n_samples_chunk)
116
- integral X_n_full_chunks = n_test // X_n_samples_chunk
117
- integral X_n_samples_rem = n_test % X_n_samples_chunk
106
+ ITYPE_t k = argkmin_indices.shape[1 ]
107
+ ITYPE_t d = X.shape[1 ]
108
+ ITYPE_t sf = sizeof(floating)
109
+ ITYPE_t si = sizeof(ITYPE_t )
110
+ ITYPE_t n_samples_chunk = max (MIN_CHUNK_SAMPLES, chunk_size)
111
+
112
+ ITYPE_t n_train = Y.shape[0 ]
113
+ ITYPE_t Y_n_samples_chunk = min (n_train, n_samples_chunk)
114
+ ITYPE_t Y_n_full_chunks = n_train / Y_n_samples_chunk
115
+ ITYPE_t Y_n_samples_rem = n_train % Y_n_samples_chunk
116
+
117
+ ITYPE_t n_test = X.shape[0 ]
118
+ ITYPE_t X_n_samples_chunk = min (n_test, n_samples_chunk)
119
+ ITYPE_t X_n_full_chunks = n_test // X_n_samples_chunk
120
+ ITYPE_t X_n_samples_rem = n_test % X_n_samples_chunk
118
121
119
122
# Counting remainder chunk in total number of chunks
120
- integral Y_n_chunks = Y_n_full_chunks + (
123
+ ITYPE_t Y_n_chunks = Y_n_full_chunks + (
121
124
n_train != (Y_n_full_chunks * Y_n_samples_chunk)
122
125
)
123
126
124
- integral X_n_chunks = X_n_full_chunks + (
127
+ ITYPE_t X_n_chunks = X_n_full_chunks + (
125
128
n_test != (X_n_full_chunks * X_n_samples_chunk)
126
129
)
127
130
128
- integral num_threads = min (Y_n_chunks, effective_n_threads)
131
+ ITYPE_t num_threads = min (Y_n_chunks, effective_n_threads)
129
132
130
- integral Y_start, Y_end, X_start, X_end
131
- integral X_chunk_idx, Y_chunk_idx, idx, jdx
133
+ ITYPE_t Y_start, Y_end, X_start, X_end
134
+ ITYPE_t X_chunk_idx, Y_chunk_idx, idx, jdx
132
135
133
136
floating * dist_middle_terms_chunks
134
137
floating * heaps_red_distances_chunks
@@ -190,9 +193,9 @@ cdef int _argkmin_on_Y(
190
193
floating[:, ::1 ] X, # IN
191
194
floating[:, ::1 ] Y, # IN
192
195
floating[::1 ] Y_sq_norms, # IN
193
- integral chunk_size, # IN
194
- integral effective_n_threads, # IN
195
- integral [:, ::1 ] argkmin_indices, # OUT
196
+ ITYPE_t chunk_size, # IN
197
+ ITYPE_t effective_n_threads, # IN
198
+ ITYPE_t [:, ::1 ] argkmin_indices, # OUT
196
199
floating[:, ::1 ] argkmin_red_distances, # OUT
197
200
) nogil:
198
201
""" Computes the argkmin of each vector (row) of X on Y
@@ -203,43 +206,43 @@ cdef int _argkmin_on_Y(
203
206
most contexts.
204
207
"""
205
208
cdef:
206
- integral k = argkmin_indices.shape[1 ]
207
- integral d = X.shape[1 ]
208
- integral sf = sizeof(floating)
209
- integral si = sizeof(integral )
210
- integral n_samples_chunk = max (MIN_CHUNK_SAMPLES, chunk_size)
211
-
212
- integral n_train = Y.shape[0 ]
213
- integral Y_n_samples_chunk = min (n_train, n_samples_chunk)
214
- integral Y_n_full_chunks = n_train / Y_n_samples_chunk
215
- integral Y_n_samples_rem = n_train % Y_n_samples_chunk
216
-
217
- integral n_test = X.shape[0 ]
218
- integral X_n_samples_chunk = min (n_test, n_samples_chunk)
219
- integral X_n_full_chunks = n_test // X_n_samples_chunk
220
- integral X_n_samples_rem = n_test % X_n_samples_chunk
209
+ ITYPE_t k = argkmin_indices.shape[1 ]
210
+ ITYPE_t d = X.shape[1 ]
211
+ ITYPE_t sf = sizeof(floating)
212
+ ITYPE_t si = sizeof(ITYPE_t )
213
+ ITYPE_t n_samples_chunk = max (MIN_CHUNK_SAMPLES, chunk_size)
214
+
215
+ ITYPE_t n_train = Y.shape[0 ]
216
+ ITYPE_t Y_n_samples_chunk = min (n_train, n_samples_chunk)
217
+ ITYPE_t Y_n_full_chunks = n_train / Y_n_samples_chunk
218
+ ITYPE_t Y_n_samples_rem = n_train % Y_n_samples_chunk
219
+
220
+ ITYPE_t n_test = X.shape[0 ]
221
+ ITYPE_t X_n_samples_chunk = min (n_test, n_samples_chunk)
222
+ ITYPE_t X_n_full_chunks = n_test // X_n_samples_chunk
223
+ ITYPE_t X_n_samples_rem = n_test % X_n_samples_chunk
221
224
222
225
# Counting remainder chunk in total number of chunks
223
- integral Y_n_chunks = Y_n_full_chunks + (
226
+ ITYPE_t Y_n_chunks = Y_n_full_chunks + (
224
227
n_train != (Y_n_full_chunks * Y_n_samples_chunk)
225
228
)
226
229
227
- integral X_n_chunks = X_n_full_chunks + (
230
+ ITYPE_t X_n_chunks = X_n_full_chunks + (
228
231
n_test != (X_n_full_chunks * X_n_samples_chunk)
229
232
)
230
233
231
- integral num_threads = min (Y_n_chunks, effective_n_threads)
234
+ ITYPE_t num_threads = min (Y_n_chunks, effective_n_threads)
232
235
233
- integral Y_start, Y_end, X_start, X_end
234
- integral X_chunk_idx, Y_chunk_idx, idx, jdx
236
+ ITYPE_t Y_start, Y_end, X_start, X_end
237
+ ITYPE_t X_chunk_idx, Y_chunk_idx, idx, jdx
235
238
236
239
floating * dist_middle_terms_chunks
237
240
floating * heaps_red_distances_chunks
238
241
239
242
# As chunks of X are shared across threads, so must their
240
243
# heaps. To solve this, each thread has its own local
241
244
# heaps which are then synchronised back in the main ones.
242
- integral * heaps_indices_chunks
245
+ ITYPE_t * heaps_indices_chunks
243
246
244
247
for X_chunk_idx in range (X_n_chunks):
245
248
X_start = X_chunk_idx * X_n_samples_chunk
@@ -256,7 +259,7 @@ cdef int _argkmin_on_Y(
256
259
Y_n_samples_chunk * X_n_samples_chunk * sf)
257
260
heaps_red_distances_chunks = < floating* > malloc(
258
261
X_n_samples_chunk * k * sf)
259
- heaps_indices_chunks = < integral * > malloc(
262
+ heaps_indices_chunks = < ITYPE_t * > malloc(
260
263
X_n_samples_chunk * k * sf)
261
264
262
265
# Initialising heaps (memset can't be used here)
@@ -318,13 +321,13 @@ cdef int _argkmin_on_Y(
318
321
cdef inline floating _euclidean_dist(
319
322
floating[:, ::1 ] X,
320
323
floating[:, ::1 ] Y,
321
- integral i,
322
- integral j,
324
+ ITYPE_t i,
325
+ ITYPE_t j,
323
326
) nogil:
324
327
cdef:
325
328
floating dist = 0
326
- integral k
327
- integral upper_unrolled_idx = (X.shape[1 ] // 4 ) * 4
329
+ ITYPE_t k
330
+ ITYPE_t upper_unrolled_idx = (X.shape[1 ] // 4 ) * 4
328
331
329
332
# Unrolling loop to help with vectorisation
330
333
for k in range (0 , upper_unrolled_idx, 4 ):
@@ -341,8 +344,8 @@ cdef inline floating _euclidean_dist(
341
344
cdef int _exact_euclidean_dist(
342
345
floating[:, ::1 ] X, # IN
343
346
floating[:, ::1 ] Y, # IN
344
- integral [:, ::1 ] Y_indices, # IN
345
- integral effective_n_threads, # IN
347
+ ITYPE_t [:, ::1 ] Y_indices, # IN
348
+ ITYPE_t effective_n_threads, # IN
346
349
floating[:, ::1 ] distances, # OUT
347
350
) nogil:
348
351
"""
@@ -356,7 +359,7 @@ cdef int _exact_euclidean_dist(
356
359
but we use a function to have a cdef nogil context.
357
360
"""
358
361
cdef:
359
- integral i, k
362
+ ITYPE_t i, k
360
363
361
364
for i in prange(X.shape[0 ], schedule = ' static' ,
362
365
nogil = True , num_threads = effective_n_threads):
@@ -370,8 +373,8 @@ cdef int _exact_euclidean_dist(
370
373
def _argkmin (
371
374
floating[:, ::1] X ,
372
375
floating[:, ::1] Y ,
373
- integral k ,
374
- integral chunk_size = CHUNK_SIZE,
376
+ ITYPE_t k ,
377
+ ITYPE_t chunk_size = CHUNK_SIZE,
375
378
str strategy = " auto" ,
376
379
bint return_distance = False ,
377
380
):
@@ -419,13 +422,13 @@ def _argkmin(
419
422
int_dtype = np.intp
420
423
float_dtype = np.float32 if floating is float else np.float64
421
424
cdef:
422
- integral [:, ::1 ] argkmin_indices = np.full((X.shape[0 ], k), 0 ,
423
- dtype = int_dtype )
425
+ ITYPE_t [:, ::1 ] argkmin_indices = np.full((X.shape[0 ], k), 0 ,
426
+ dtype = ITYPE )
424
427
floating[:, ::1 ] argkmin_distances = np.full((X.shape[0 ], k),
425
428
FLOAT_INF,
426
429
dtype = float_dtype)
427
430
floating[::1 ] Y_sq_norms = np.einsum(' ij,ij->i' , Y, Y)
428
- integral effective_n_threads = _openmp_effective_n_threads()
431
+ ITYPE_t effective_n_threads = _openmp_effective_n_threads()
429
432
430
433
if strategy == ' auto' :
431
434
# This is a simple heuristic whose constant for the
0 commit comments