diff --git a/sklearn/metrics/_dist_metrics.pxd.tp b/sklearn/metrics/_dist_metrics.pxd.tp index ef23f2af50ffb..8f8aa21107015 100644 --- a/sklearn/metrics/_dist_metrics.pxd.tp +++ b/sklearn/metrics/_dist_metrics.pxd.tp @@ -3,7 +3,7 @@ implementation_specific_values = [ # Values are the following ones: # - # name_suffix, DTYPE_t, DTYPE + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE # # On the first hand, an empty string is used for `name_suffix` # for the float64 case as to still be able to expose the original @@ -28,9 +28,9 @@ implementation_specific_values = [ cimport numpy as cnp from libc.math cimport sqrt, exp -from ..utils._typedefs cimport DTYPE_t, ITYPE_t +from ..utils._typedefs cimport DTYPE_t, ITYPE_t, SPARSE_INDEX_TYPE_t -{{for name_suffix, DTYPE_t, DTYPE in implementation_specific_values}} +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} ###################################################################### # Inline distance functions @@ -38,8 +38,8 @@ from ..utils._typedefs cimport DTYPE_t, ITYPE_t # We use these for the default (euclidean) case so that they can be # inlined. This leads to faster computation for the most common case cdef inline DTYPE_t euclidean_dist{{name_suffix}}( - const {{DTYPE_t}}* x1, - const {{DTYPE_t}}* x2, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, ITYPE_t size, ) nogil except -1: cdef DTYPE_t tmp, d=0 @@ -51,8 +51,8 @@ cdef inline DTYPE_t euclidean_dist{{name_suffix}}( cdef inline DTYPE_t euclidean_rdist{{name_suffix}}( - const {{DTYPE_t}}* x1, - const {{DTYPE_t}}* x2, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, ITYPE_t size, ) nogil except -1: cdef DTYPE_t tmp, d=0 @@ -63,11 +63,11 @@ cdef inline DTYPE_t euclidean_rdist{{name_suffix}}( return d -cdef inline DTYPE_t euclidean_dist_to_rdist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1: +cdef inline DTYPE_t euclidean_dist_to_rdist{{name_suffix}}(const {{INPUT_DTYPE_t}} dist) nogil except -1: return dist * dist -cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1: +cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{INPUT_DTYPE_t}} dist) nogil except -1: return sqrt(dist) @@ -78,26 +78,89 @@ cdef class DistanceMetric{{name_suffix}}: # we must define them here so that cython's limited polymorphism will work. # Because we don't expect to instantiate a lot of these objects, the # extra memory overhead of this setup should not be an issue. - cdef {{DTYPE_t}} p - cdef {{DTYPE_t}}[::1] vec - cdef {{DTYPE_t}}[:, ::1] mat + cdef DTYPE_t p + cdef DTYPE_t[::1] vec + cdef DTYPE_t[:, ::1] mat cdef ITYPE_t size cdef object func cdef object kwargs - cdef DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1 - - cdef DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1 - - cdef int pdist(self, const {{DTYPE_t}}[:, ::1] X, {{DTYPE_t}}[:, ::1] D) except -1 - - cdef int cdist(self, const {{DTYPE_t}}[:, ::1] X, const {{DTYPE_t}}[:, ::1] Y, - {{DTYPE_t}}[:, ::1] D) except -1 - - cdef DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1 - - cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1 + cdef DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1 + + cdef DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1 + + cdef DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1 + + cdef DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1 + + cdef int pdist( + self, + const {{INPUT_DTYPE_t}}[:, ::1] X, + DTYPE_t[:, ::1] D, + ) except -1 + + cdef int cdist( + self, + const {{INPUT_DTYPE_t}}[:, ::1] X, + const {{INPUT_DTYPE_t}}[:, ::1] Y, + DTYPE_t[:, ::1] D, + ) except -1 + + cdef int pdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const SPARSE_INDEX_TYPE_t[:] x1_indptr, + const ITYPE_t size, + DTYPE_t[:, ::1] D, + ) nogil except -1 + + cdef int cdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const SPARSE_INDEX_TYPE_t[:] x1_indptr, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t[:] x2_indptr, + const ITYPE_t size, + DTYPE_t[:, ::1] D, + ) nogil except -1 + + cdef DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1 + + cdef DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1 {{endfor}} diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 47bd1dcbab519..a7574bff86510 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -3,7 +3,7 @@ implementation_specific_values = [ # Values are the following ones: # - # name_suffix, DTYPE_t, DTYPE + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE # # # On the first hand, an empty string is used for `name_suffix` @@ -85,8 +85,9 @@ def get_valid_metric_ids(L): return [key for (key, val) in METRIC_MAPPING.items() if (val.__name__ in L) or (val in L)] +from ..utils._typedefs import SPARSE_INDEX_TYPE -{{for name_suffix, DTYPE_t, DTYPE in implementation_specific_values}} +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} ###################################################################### # metric mappings @@ -119,7 +120,7 @@ METRIC_MAPPING{{name_suffix}} = { 'pyfunc': PyFuncDistance{{name_suffix}}, } -cdef inline cnp.ndarray _buffer_to_ndarray{{name_suffix}}(const {{DTYPE_t}}* x, cnp.npy_intp n): +cdef inline cnp.ndarray _buffer_to_ndarray{{name_suffix}}(const {{INPUT_DTYPE_t}}* x, cnp.npy_intp n): # Wrap a memory buffer with an ndarray. Warning: this is not robust. # In particular, if x is deallocated before the returned array goes # out of scope, this could cause memory errors. Since there is not @@ -129,7 +130,7 @@ cdef inline cnp.ndarray _buffer_to_ndarray{{name_suffix}}(const {{DTYPE_t}}* x, return PyArray_SimpleNewFromData(1, &n, DTYPECODE, x) -cdef {{DTYPE_t}} INF{{name_suffix}} = np.inf +cdef {{INPUT_DTYPE_t}} INF{{name_suffix}} = np.inf ###################################################################### @@ -249,8 +250,8 @@ cdef class DistanceMetric{{name_suffix}}: """ def __cinit__(self): self.p = 2 - self.vec = np.zeros(1, dtype={{DTYPE}}, order='C') - self.mat = np.zeros((1, 1), dtype={{DTYPE}}, order='C') + self.vec = np.zeros(1, dtype=DTYPE, order='C') + self.mat = np.zeros((1, 1), dtype=DTYPE, order='C') self.size = 1 def __reduce__(self): @@ -334,16 +335,24 @@ cdef class DistanceMetric{{name_suffix}}: """ return - cdef DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: """Compute the distance between vectors x1 and x2 This should be overridden in a base class. """ return -999 - cdef DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: """Compute the rank-preserving surrogate distance between vectors x1 and x2. This can optionally be overridden in a base class. @@ -355,8 +364,12 @@ cdef class DistanceMetric{{name_suffix}}: """ return self.dist(x1, x2, size) - cdef int pdist(self, const {{DTYPE_t}}[:, ::1] X, {{DTYPE_t}}[:, ::1] D) except -1: - """compute the pairwise distances between points in X""" + cdef int pdist( + self, + const {{INPUT_DTYPE_t}}[:, ::1] X, + DTYPE_t[:, ::1] D, + ) except -1: + """Compute the pairwise distances between points in X""" cdef ITYPE_t i1, i2 for i1 in range(X.shape[0]): for i2 in range(i1, X.shape[0]): @@ -364,9 +377,14 @@ cdef class DistanceMetric{{name_suffix}}: D[i2, i1] = D[i1, i2] return 0 - cdef int cdist(self, const {{DTYPE_t}}[:, ::1] X, const {{DTYPE_t}}[:, ::1] Y, - {{DTYPE_t}}[:, ::1] D) except -1: - """compute the cross-pairwise distances between arrays X and Y""" + + cdef int cdist( + self, + const {{INPUT_DTYPE_t}}[:, ::1] X, + const {{INPUT_DTYPE_t}}[:, ::1] Y, + DTYPE_t[:, ::1] D, + ) except -1: + """Compute the cross-pairwise distances between arrays X and Y""" cdef ITYPE_t i1, i2 if X.shape[1] != Y.shape[1]: raise ValueError('X and Y must have the same second dimension') @@ -375,11 +393,188 @@ cdef class DistanceMetric{{name_suffix}}: D[i1, i2] = self.dist(&X[i1, 0], &Y[i2, 0], X.shape[1]) return 0 - cdef DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: + cdef DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + """Compute the distance between vectors x1 and x2 represented + under the CSR format. + + This must be overridden in a subclass. + + Notes + ----- + The implementation of this method in subclasses must be robust to the + presence of explicit zeros in the CSR representation. + + An alternative signature would be: + + cdef DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + ) nogil except -1: + + Where calles would use slicing on the original CSR data and indices + memoryview: + + x1_start = X1_csr.indices_ptr[i] + x1_end = X1_csr.indices_ptr[i+1] + x2_start = X2_csr.indices_ptr[j] + x2_end = X2_csr.indices_ptr[j+1] + + self.dist_csr( + x1_data[x1_start:x1_end], + x1_indices[x1_start:x1_end], + x2_data[x2_start:x2_end], + x2_indices[x2_start:x2_end], + ) + + Yet, slicing on memoryview slows down execution as it takes the GIL. + See: https://github.com/scikit-learn/scikit-learn/issues/17299 + + Hence, to avoid slicing the data and indices arrays of the sparse + matrices containing respectively x1 and x2 (namely x{1,2}_{data,indice}) + are passed as well as their indice pointers (namely x{1,2}_{start,end}). + + For reference about the CSR format, see section 3.4 of + Saad, Y. (2003), Iterative Methods for Sparse Linear Systems, SIAM. + https://www-users.cse.umn.edu/~saad/IterMethBook_2ndEd.pdf + """ + return -999 + + cdef DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + """Distance between rows of CSR matrices x1 and x2. + + This can optionally be overridden in a subclass. + + The rank-preserving surrogate distance is any measure that yields the same + rank as the distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. + + Notes + ----- + The implementation of this method in subclasses must be robust to the + presence of explicit zeros in the CSR representation. + + More information about the motives for this method signature is given + in the docstring of dist_csr. + """ + return self.dist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ) + + cdef int pdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const SPARSE_INDEX_TYPE_t[:] x1_indptr, + const ITYPE_t size, + DTYPE_t[:, ::1] D, + ) nogil except -1: + """Pairwise distances between rows in CSR matrix X. + + Note that this implementation is twice faster than cdist_csr(X, X) + because it leverages the symmetry of the problem. + """ + cdef: + ITYPE_t i1, i2 + ITYPE_t n_x1 = x1_indptr.shape[0] - 1 + ITYPE_t x1_start, x1_end, x2_start, x2_end + + for i1 in range(n_x1): + x1_start = x1_indptr[i1] + x1_end = x1_indptr[i1 + 1] + for i2 in range(i1, n_x1): + x2_start = x1_indptr[i2] + x2_end = x1_indptr[i2 + 1] + D[i1, i2] = D[i2, i1] = self.dist_csr( + x1_data, + x1_indices, + x1_data, + x1_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ) + return 0 + + cdef int cdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const SPARSE_INDEX_TYPE_t[:] x1_indptr, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t[:] x2_indptr, + const ITYPE_t size, + DTYPE_t[:, ::1] D, + ) nogil except -1: + """Compute the cross-pairwise distances between arrays X and Y + represented in the CSR format.""" + cdef: + ITYPE_t i1, i2 + ITYPE_t n_x1 = x1_indptr.shape[0] - 1 + ITYPE_t n_x2 = x2_indptr.shape[0] - 1 + ITYPE_t x1_start, x1_end, x2_start, x2_end + + for i1 in range(n_x1): + x1_start = x1_indptr[i1] + x1_end = x1_indptr[i1 + 1] + for i2 in range(n_x2): + x2_start = x2_indptr[i2] + x2_end = x2_indptr[i2 + 1] + + D[i1, i2] = self.dist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ) + return 0 + + cdef DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: """Convert the rank-preserving surrogate distance to the distance""" return rdist - cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: + cdef DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: """Convert the distance to the rank-preserving surrogate distance""" return dist @@ -423,6 +618,127 @@ cdef class DistanceMetric{{name_suffix}}: """ return dist + def _pairwise_dense_dense(self, X, Y): + cdef cnp.ndarray[{{INPUT_DTYPE_t}}, ndim=2, mode='c'] Xarr + cdef cnp.ndarray[{{INPUT_DTYPE_t}}, ndim=2, mode='c'] Yarr + cdef cnp.ndarray[DTYPE_t, ndim=2, mode='c'] Darr + + Xarr = np.asarray(X, dtype={{INPUT_DTYPE}}, order='C') + self._validate_data(Xarr) + if X is Y: + Darr = np.empty((Xarr.shape[0], Xarr.shape[0]), dtype=DTYPE, order='C') + self.pdist(Xarr, Darr) + else: + Yarr = np.asarray(Y, dtype={{INPUT_DTYPE}}, order='C') + self._validate_data(Yarr) + Darr = np.empty((Xarr.shape[0], Yarr.shape[0]), dtype=DTYPE, order='C') + self.cdist(Xarr, Yarr, Darr) + return Darr + + def _pairwise_sparse_sparse(self, X, Y): + X_csr = X.tocsr() + n_X, size = X_csr.shape + X_data = np.asarray(X_csr.data, dtype={{INPUT_DTYPE}}) + X_indices = np.asarray(X_csr.indices, dtype=SPARSE_INDEX_TYPE) + X_indptr = np.asarray(X_csr.indptr, dtype=SPARSE_INDEX_TYPE) + + if X is Y: + Darr = np.empty((n_X, n_X), dtype=DTYPE, order='C') + self.pdist_csr( + x1_data=X_data, + x1_indices=X_indices, + x1_indptr=X_indptr, + size=size, + D=Darr, + ) + else: + Y_csr = Y.tocsr() + n_Y, _ = Y_csr.shape + Y_data = np.asarray(Y_csr.data, dtype={{INPUT_DTYPE}}) + Y_indices = np.asarray(Y_csr.indices, dtype=SPARSE_INDEX_TYPE) + Y_indptr = np.asarray(Y_csr.indptr, dtype=SPARSE_INDEX_TYPE) + + Darr = np.empty((n_X, n_Y), dtype=DTYPE, order='C') + self.cdist_csr( + x1_data=X_data, + x1_indices=X_indices, + x1_indptr=X_indptr, + x2_data=Y_data, + x2_indices=Y_indices, + x2_indptr=Y_indptr, + size=size, + D=Darr, + ) + return Darr + + def _pairwise_sparse_dense(self, X, Y): + n_X, size = X.shape + X_data = np.asarray(X.data, dtype={{INPUT_DTYPE}}) + X_indices = np.asarray(X.indices, dtype=SPARSE_INDEX_TYPE) + X_indptr = np.asarray(X.indptr, dtype=SPARSE_INDEX_TYPE) + + # To avoid introducing redundant implementations for the CSR × dense array + # case, we wrap the dense array into a fake CSR datastructure and leverage + # the existing code for the CSR × CSR case. + # The true CSR representation of a dense array would require allocating + # a Y_indices matrix of shape (n_samples, n_features) with repeated + # contiguous integers from 0 to n_features - 1 on each row which would + # be very wasteful from a memory point of view. Instead we only allocate + # a single row and adapt the CSR × CSR routines to use a modulo operation + # when accessing Y_indices in order to achieve the same result without having + # to materialize the indices repetition explicitly. + + n_Y, _ = Y.shape + Y_data = Y.reshape(-1) + Y_indices = np.arange(size, dtype=SPARSE_INDEX_TYPE) + Y_indptr = np.arange( + start=0, stop=size * (n_Y + 1), step=size, dtype=SPARSE_INDEX_TYPE + ) + + Darr = np.empty((n_X, n_Y), dtype=DTYPE, order='C') + self.cdist_csr( + x1_data=X_data, + x1_indices=X_indices, + x1_indptr=X_indptr, + x2_data=Y_data, + x2_indices=Y_indices, + x2_indptr=Y_indptr, + size=size, + D=Darr, + ) + return Darr + + def _pairwise_dense_sparse(self, X, Y): + # Same remark as in _pairwise_sparse_dense. We could + # have implemented this method using _pairwise_dense_sparse, + # but this would have come with an extra copy to ensure + # c-contiguity of the result. + n_Y, size = Y.shape + Y_data = np.asarray(Y.data, dtype={{INPUT_DTYPE}}) + Y_indices = np.asarray(Y.indices, dtype=SPARSE_INDEX_TYPE) + Y_indptr = np.asarray(Y.indptr, dtype=SPARSE_INDEX_TYPE) + + n_X, _ = X.shape + X_data = X.reshape(-1) + X_indices = np.arange(size, dtype=SPARSE_INDEX_TYPE) + X_indptr = np.arange( + start=0, stop=size * (n_X + 1), step=size, dtype=SPARSE_INDEX_TYPE + ) + + Darr = np.empty((n_X, n_Y), dtype=DTYPE, order='C') + self.cdist_csr( + x1_data=X_data, + x1_indices=X_indices, + x1_indptr=X_indptr, + x2_data=Y_data, + x2_indices=Y_indices, + x2_indptr=Y_indptr, + size=size, + D=Darr, + ) + return Darr + + def pairwise(self, X, Y=None): """Compute the pairwise distances between X and Y @@ -432,36 +748,34 @@ cdef class DistanceMetric{{name_suffix}}: Parameters ---------- - X : array-like - Array of shape (Nx, D), representing Nx points in D dimensions. - Y : array-like (optional) - Array of shape (Ny, D), representing Ny points in D dimensions. + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. If not specified, then Y=X. Returns ------- - dist : ndarray - The shape (Nx, Ny) array of pairwise distances between points in - X and Y. + dist : ndarray of shape (n_samples_X, n_samples_Y) + The distance matrix of pairwise distances between points in X and Y. """ - cdef cnp.ndarray[{{DTYPE_t}}, ndim=2, mode='c'] Xarr - cdef cnp.ndarray[{{DTYPE_t}}, ndim=2, mode='c'] Yarr - cdef cnp.ndarray[{{DTYPE_t}}, ndim=2, mode='c'] Darr + X = check_array(X, accept_sparse=['csr']) - Xarr = np.asarray(X, dtype={{DTYPE}}, order='C') - self._validate_data(Xarr) if Y is None: - Darr = np.zeros((Xarr.shape[0], Xarr.shape[0]), - dtype={{DTYPE}}, order='C') - self.pdist(Xarr, Darr) + Y = X else: - Yarr = np.asarray(Y, dtype={{DTYPE}}, order='C') - self._validate_data(Yarr) - Darr = np.zeros((Xarr.shape[0], Yarr.shape[0]), - dtype={{DTYPE}}, order='C') - self.cdist(Xarr, Yarr, Darr) - return Darr + Y = check_array(Y, accept_sparse=['csr']) + + X_is_sparse = issparse(X) + Y_is_sparse = issparse(Y) + if not X_is_sparse and not Y_is_sparse: + return self._pairwise_dense_dense(X, Y) + if X_is_sparse and Y_is_sparse: + return self._pairwise_sparse_sparse(X, Y) + if X_is_sparse and not Y_is_sparse: + return self._pairwise_sparse_dense(X, Y) + return self._pairwise_dense_sparse(X, Y) #------------------------------------------------------------ # Euclidean Distance @@ -475,18 +789,24 @@ cdef class EuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def __init__(self): self.p = 2 - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist(self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return euclidean_dist{{name_suffix}}(x1, x2, size) - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t rdist(self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return euclidean_rdist{{name_suffix}}(x1, x2, size) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: return sqrt(rdist) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: return dist * dist def rdist_to_dist(self, rdist): @@ -495,6 +815,87 @@ cdef class EuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def dist_to_rdist(self, dist): return dist ** 2 + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + DTYPE_t unsquared = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + unsquared = x1_data[i1] - x2_data[i2] + d = d + (unsquared * unsquared) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + unsquared = x1_data[i1] + d = d + (unsquared * unsquared) + i1 = i1 + 1 + else: + unsquared = x2_data[i2] + d = d + (unsquared * unsquared) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + unsquared = x2_data[i2] + d = d + (unsquared * unsquared) + i2 = i2 + 1 + else: + while i1 < x1_end: + unsquared = x1_data[i1] + d = d + (unsquared * unsquared) + i1 = i1 + 1 + + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return sqrt( + self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + )) #------------------------------------------------------------ # SEuclidean Distance @@ -506,7 +907,7 @@ cdef class SEuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} } """ def __init__(self, V): - self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype={{DTYPE}})) + self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype=DTYPE)) self.size = self.vec.shape[0] self.p = 2 @@ -514,23 +915,31 @@ cdef class SEuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): if X.shape[1] != self.size: raise ValueError('SEuclidean dist: size of V does not match') - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t tmp, d=0 cdef cnp.intp_t j for j in range(size): - tmp = (x1[j] - x2[j]) - d += (tmp * tmp / self.vec[j]) + tmp = x1[j] - x2[j] + d += (tmp * tmp / self.vec[j]) return d - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return sqrt(self.rdist(x1, x2, size)) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: return sqrt(rdist) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: return dist * dist def rdist_to_dist(self, rdist): @@ -539,6 +948,88 @@ cdef class SEuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def dist_to_rdist(self, dist): return dist ** 2 + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + DTYPE_t unsquared = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + unsquared = x1_data[i1] - x2_data[i2] + d = d + (unsquared * unsquared) / self.vec[ix1] + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + unsquared = x1_data[i1] + d = d + (unsquared * unsquared) / self.vec[ix1] + i1 = i1 + 1 + else: + unsquared = x2_data[i2] + d = d + (unsquared * unsquared) / self.vec[ix2] + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + ix2 = x2_indices[i2 % len_x2_indices] + unsquared = x2_data[i2] + d = d + (unsquared * unsquared) / self.vec[ix2] + i2 = i2 + 1 + else: + while i1 < x1_end: + ix1 = x1_indices[i1 % len_x1_indices] + unsquared = x1_data[i1] + d = d + (unsquared * unsquared) / self.vec[ix1] + i1 = i1 + 1 + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return sqrt( + self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + )) #------------------------------------------------------------ # Manhattan Distance @@ -552,12 +1043,67 @@ cdef class ManhattanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def __init__(self): self.p = 1 - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t d = 0 cdef cnp.intp_t j for j in range(size): - d += fabs(x1[j] - x2[j]) + d += fabs(x1[j] - x2[j]) + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d = d + fabs(x1_data[i1] - x2_data[i2]) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d = d + fabs(x1_data[i1]) + i1 = i1 + 1 + else: + d = d + fabs(x2_data[i2]) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + d = d + fabs(x2_data[i2]) + i2 = i2 + 1 + else: + while i1 < x1_end: + d = d + fabs(x1_data[i1]) + i1 = i1 + 1 + return d @@ -585,12 +1131,68 @@ cdef class ChebyshevDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def __init__(self): self.p = INF{{name_suffix}} - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t d = 0 cdef cnp.intp_t j for j in range(size): - d = fmax(d, fabs(x1[j] - x2[j])) + d = fmax(d, fabs(x1[j] - x2[j])) + return d + + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d = fmax(d, fabs(x1_data[i1] - x2_data[i2])) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d = fmax(d, fabs(x1_data[i1])) + i1 = i1 + 1 + else: + d = fmax(d, fabs(x2_data[i2])) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + d = fmax(d, fabs(x2_data[i2])) + i2 = i2 + 1 + else: + while i1 < x1_end: + d = fmax(d, fabs(x1_data[i1])) + i1 = i1 + 1 + return d @@ -631,14 +1233,14 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): self.p = p if w is not None: w_array = check_array( - w, ensure_2d=False, dtype={{DTYPE}}, input_name="w" + w, ensure_2d=False, dtype=DTYPE, input_name="w" ) if (w_array < 0).any(): raise ValueError("w cannot contain negative weights") self.vec = ReadonlyArrayWrapper(w_array) self.size = self.vec.shape[0] else: - self.vec = ReadonlyArrayWrapper(np.asarray([], dtype={{DTYPE}})) + self.vec = ReadonlyArrayWrapper(np.asarray([], dtype=DTYPE)) self.size = 0 def _validate_data(self, X): @@ -647,28 +1249,36 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): f"the number of features ({X.shape[1]}). " f"Currently len(w)={self.size}.") - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t d=0 cdef cnp.intp_t j cdef bint has_w = self.size > 0 if has_w: for j in range(size): - d += (self.vec[j] * pow(fabs(x1[j] - x2[j]), self.p)) + d += (self.vec[j] * pow(fabs(x1[j] - x2[j]), self.p)) else: for j in range(size): - d += (pow(fabs(x1[j] - x2[j]), self.p)) + d += (pow(fabs(x1[j] - x2[j]), self.p)) return d - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - return pow(self.rdist(x1, x2, size), 1. / self.p) + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + return pow(self.rdist(x1, x2, size), 1. / self.p) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: - return pow(rdist, 1. / self.p) + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: + return pow(rdist, 1. / self.p) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: - return pow(dist, self.p) + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: + return pow(dist, self.p) def rdist_to_dist(self, rdist): return rdist ** (1. / self.p) @@ -676,6 +1286,120 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def dist_to_rdist(self, dist): return dist ** self.p + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + bint has_w = self.size > 0 + + if has_w: + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d = d + (self.vec[ix1] * pow(fabs( + x1_data[i1] - x2_data[i2] + ), self.p)) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d = d + (self.vec[ix1] * pow(fabs(x1_data[i1]), self.p)) + i1 = i1 + 1 + else: + d = d + (self.vec[ix2] * pow(fabs(x2_data[i2]), self.p)) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + ix2 = x2_indices[i2 % len_x2_indices] + d = d + (self.vec[ix2] * pow(fabs(x2_data[i2]), self.p)) + i2 = i2 + 1 + else: + while i1 < x1_end: + ix1 = x1_indices[i1 % len_x1_indices] + d = d + (self.vec[ix1] * pow(fabs(x1_data[i1]), self.p)) + i1 = i1 + 1 + + return d + else: + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d = d + (pow(fabs( + x1_data[i1] - x2_data[i2] + ), self.p)) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d = d + (pow(fabs(x1_data[i1]), self.p)) + i1 = i1 + 1 + else: + d = d + (pow(fabs(x2_data[i2]), self.p)) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + d = d + (pow(fabs(x2_data[i2]), self.p)) + i2 = i2 + 1 + else: + while i1 < x1_end: + d = d + (pow(fabs(x1_data[i1]), self.p)) + i1 = i1 + 1 + + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return pow( + self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ), + 1 / self.p + ) #------------------------------------------------------------ # TODO: Remove in 1.3 - WMinkowskiDistance class @@ -713,7 +1437,7 @@ cdef class WMinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): raise ValueError("WMinkowskiDistance requires finite p. " "For p=inf, use ChebyshevDistance.") self.p = p - self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype={{DTYPE}})) + self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype=DTYPE)) self.size = self.vec.shape[0] def _validate_data(self, X): @@ -721,24 +1445,32 @@ cdef class WMinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): raise ValueError('WMinkowskiDistance dist: ' 'size of w does not match') - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t d = 0 cdef cnp.intp_t j for j in range(size): - d += (pow(self.vec[j] * fabs(x1[j] - x2[j]), self.p)) + d += (pow(self.vec[j] * fabs(x1[j] - x2[j]), self.p)) return d - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - return pow(self.rdist(x1, x2, size), 1. / self.p) + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + return pow(self.rdist(x1, x2, size), 1. / self.p) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: - return pow(rdist, 1. / self.p) + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: + return pow(rdist, 1. / self.p) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: - return pow(dist, self.p) + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: + return pow(dist, self.p) def rdist_to_dist(self, rdist): return rdist ** (1. / self.p) @@ -746,6 +1478,87 @@ cdef class WMinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def dist_to_rdist(self, dist): return dist ** self.p + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d = d + pow(self.vec[ix1] * fabs( + x1_data[i1] - x2_data[i2] + ), self.p) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d = d + pow(self.vec[ix1] * fabs(x1_data[i1]), self.p) + i1 = i1 + 1 + else: + d = d + pow(self.vec[ix2] * fabs(x2_data[i2]), self.p) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + ix2 = x2_indices[i2 % len_x2_indices] + d = d + pow(self.vec[ix2] * fabs(x2_data[i2]), self.p) + i2 = i2 + 1 + else: + while i1 < x1_end: + ix1 = x1_indices[i1 % len_x1_indices] + d = d + pow(self.vec[ix1] * fabs(x1_data[i1]), self.p) + i1 = i1 + 1 + + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return pow( + self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ), + 1 / self.p + ) #------------------------------------------------------------ # Mahalanobis Distance @@ -774,19 +1587,23 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): if VI.ndim != 2 or VI.shape[0] != VI.shape[1]: raise ValueError("V/VI must be square") - self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype={{DTYPE}}, order='C')) + self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype=DTYPE, order='C')) self.size = self.mat.shape[0] # we need vec as a work buffer - self.vec = np.zeros(self.size, dtype={{DTYPE}}) + self.vec = np.zeros(self.size, dtype=DTYPE) def _validate_data(self, X): if X.shape[1] != self.size: raise ValueError('Mahalanobis dist: size of V does not match') - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t rdist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t tmp, d = 0 cdef cnp.intp_t i, j @@ -801,14 +1618,18 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): d += tmp * self.vec[i] return d - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return sqrt(self.rdist(x1, x2, size)) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: return sqrt(rdist) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: return dist * dist def rdist_to_dist(self, rdist): @@ -817,6 +1638,89 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): def dist_to_rdist(self, dist): return dist ** 2 + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t tmp, d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + self.vec[ix1] = x1_data[i1] - x2_data[i2] + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + self.vec[ix1] = x1_data[i1] + i1 = i1 + 1 + else: + self.vec[ix2] = - x2_data[i2] + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + ix2 = x2_indices[i2 % len_x2_indices] + self.vec[ix2] = - x2_data[i2] + i2 = i2 + 1 + else: + while i1 < x1_end: + ix1 = x1_indices[i1 % len_x1_indices] + self.vec[ix1] = x1_data[i1] + i1 = i1 + 1 + + for i in range(size): + tmp = 0 + for j in range(size): + tmp += self.mat[i, j] * self.vec[j] + d += tmp * self.vec[i] + + return d + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return sqrt( + self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + )) #------------------------------------------------------------ # Hamming Distance @@ -830,8 +1734,12 @@ cdef class HammingDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): .. math:: D(x, y) = \frac{1}{N} \sum_i \delta_{x_i, y_i} """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef int n_unequal = 0 cdef cnp.intp_t j for j in range(size): @@ -840,6 +1748,60 @@ cdef class HammingDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): return float(n_unequal) / size + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d += (x1_data[i1] != x2_data[i2]) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d += (x1_data[i1] != 0) + i1 = i1 + 1 + else: + d += (x2_data[i2] != 0) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + d += (x2_data[i2] != 0) + i2 = i2 + 1 + else: + while i1 < x1_end: + d += (x1_data[i1] != 0) + i1 = i1 + 1 + + d /= size + + return d + + #------------------------------------------------------------ # Canberra Distance # D(x, y) = sum[ abs(x_i - y_i) / (abs(x_i) + abs(y_i)) ] @@ -852,16 +1814,73 @@ cdef class CanberraDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): .. math:: D(x, y) = \sum_i \frac{|x_i - y_i|}{|x_i| + |y_i|} """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t denom, d = 0 cdef cnp.intp_t j for j in range(size): - denom = (fabs(x1[j]) + fabs(x2[j])) + denom = fabs(x1[j]) + fabs(x2[j]) if denom > 0: - d += (fabs(x1[j] - x2[j])) / denom + d += fabs(x1[j] - x2[j]) / denom return d + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t d = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + d += ( + fabs(x1_data[i1] - x2_data[i2]) / + (fabs(x1_data[i1]) + fabs(x2_data[i2])) + ) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + d += 1. + i1 = i1 + 1 + else: + d += 1. + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + d += 1. + i2 = i2 + 1 + else: + while i1 < x1_end: + d += 1. + i1 = i1 + 1 + + return d #------------------------------------------------------------ # Bray-Curtis Distance @@ -875,18 +1894,78 @@ cdef class BrayCurtisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): .. math:: D(x, y) = \frac{\sum_i |x_i - y_i|}{\sum_i(|x_i| + |y_i|)} """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef DTYPE_t num = 0, denom = 0 cdef cnp.intp_t j for j in range(size): - num += fabs(x1[j] - x2[j]) - denom += (fabs(x1[j]) + fabs(x2[j])) + num += fabs(x1[j] - x2[j]) + denom += fabs(x1[j]) + fabs(x2[j]) if denom > 0: return num / denom else: return 0.0 + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t num = 0.0 + DTYPE_t denom = 0.0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + num += fabs(x1_data[i1] - x2_data[i2]) + denom += fabs(x1_data[i1]) + fabs(x2_data[i2]) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + num += fabs(x1_data[i1]) + denom += fabs(x1_data[i1]) + i1 = i1 + 1 + else: + num += fabs(x2_data[i2]) + denom += fabs(x2_data[i2]) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + num += fabs(x1_data[i1]) + denom += fabs(x1_data[i1]) + i2 = i2 + 1 + else: + while i1 < x1_end: + num += fabs(x2_data[i2]) + denom += fabs(x2_data[i2]) + i1 = i1 + 1 + + return num / denom #------------------------------------------------------------ # Jaccard Distance (boolean) @@ -900,8 +1979,12 @@ cdef class JaccardDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = (N_TF + N_FT) / (N_TT + N_TF + N_FT) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef int tf1, tf2, n_eq = 0, nnz = 0 cdef cnp.intp_t j for j in range(size): @@ -916,6 +1999,67 @@ cdef class JaccardDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): return 0 return (nnz - n_eq) * 1.0 / nnz + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_tt = 0, nnz = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + nnz += (tf1 or tf2) + n_tt += (tf1 and tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + nnz += tf1 + i1 = i1 + 1 + else: + nnz += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + nnz += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + nnz += tf1 + i1 = i1 + 1 + + # Based on https://github.com/scipy/scipy/pull/7373 + # When comparing two all-zero vectors, scipy>=1.2.0 jaccard metric + # was changed to return 0, instead of nan. + if nnz == 0: + return 0 + return (nnz - n_tt) * 1.0 / nnz #------------------------------------------------------------ # Matching Distance (boolean) @@ -929,8 +2073,12 @@ cdef class MatchingDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = (N_TF + N_FT) / N """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef int tf1, tf2, n_neq = 0 cdef cnp.intp_t j for j in range(size): @@ -939,6 +2087,58 @@ cdef class MatchingDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): n_neq += (tf1 != tf2) return n_neq * 1. / size + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + if ix1 == ix2: + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += (x1_data[i1] != 0) + i1 = i1 + 1 + else: + n_neq += (x2_data[i2] != 0) + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + n_neq += (x2_data[i2] != 0) + i2 = i2 + 1 + else: + while i1 < x1_end: + n_neq += (x1_data[i1] != 0) + i1 = i1 + 1 + + return n_neq * 1.0 / size #------------------------------------------------------------ # Dice Distance (boolean) @@ -953,16 +2153,77 @@ cdef class DiceDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = (N_TF + N_FT) / (2 * N_TT + N_TF + N_FT) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - cdef int tf1, tf2, n_neq = 0, ntt = 0 + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef int tf1, tf2, n_neq = 0, n_tt = 0 cdef cnp.intp_t j for j in range(size): tf1 = x1[j] != 0 tf2 = x2[j] != 0 - ntt += (tf1 and tf2) + n_tt += (tf1 and tf2) n_neq += (tf1 != tf2) - return n_neq / (2.0 * ntt + n_neq) + return n_neq / (2.0 * n_tt + n_neq) + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_tt = 0, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_tt += (tf1 and tf2) + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += tf1 + i1 = i1 + 1 + else: + n_neq += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + n_neq += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + n_neq += tf1 + i1 = i1 + 1 + + return n_neq / (2.0 * n_tt + n_neq) #------------------------------------------------------------ @@ -978,17 +2239,77 @@ cdef class KulsinskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = 1 - N_TT / (N + N_TF + N_FT) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - cdef int tf1, tf2, ntt = 0, n_neq = 0 + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef int tf1, tf2, n_tt = 0, n_neq = 0 cdef cnp.intp_t j for j in range(size): tf1 = x1[j] != 0 tf2 = x2[j] != 0 n_neq += (tf1 != tf2) - ntt += (tf1 and tf2) - return (n_neq - ntt + size) * 1.0 / (n_neq + size) + n_tt += (tf1 and tf2) + return (n_neq - n_tt + size) * 1.0 / (n_neq + size) + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_tt = 0, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_tt += (tf1 and tf2) + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += tf1 + i1 = i1 + 1 + else: + n_neq += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + n_neq += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + n_neq += tf1 + i1 = i1 + 1 + return (n_neq - n_tt + size) * 1.0 / (n_neq + size) #------------------------------------------------------------ # Rogers-Tanimoto Distance (boolean) @@ -1002,8 +2323,12 @@ cdef class RogersTanimotoDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = 2 (N_TF + N_FT) / (N + N_TF + N_FT) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef int tf1, tf2, n_neq = 0 cdef cnp.intp_t j for j in range(size): @@ -1012,6 +2337,61 @@ cdef class RogersTanimotoDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): n_neq += (tf1 != tf2) return (2.0 * n_neq) / (size + n_neq) + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += tf1 + i1 = i1 + 1 + else: + n_neq += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + n_neq += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + n_neq += tf1 + i1 = i1 + 1 + + return (2.0 * n_neq) / (size + n_neq) #------------------------------------------------------------ # Russell-Rao Distance (boolean) @@ -1025,15 +2405,67 @@ cdef class RussellRaoDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = (N - N_TT) / N """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - cdef int tf1, tf2, ntt = 0 + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef int tf1, tf2, n_tt = 0 cdef cnp.intp_t j for j in range(size): tf1 = x1[j] != 0 tf2 = x2[j] != 0 - ntt += (tf1 and tf2) - return (size - ntt) * 1. / size + n_tt += (tf1 and tf2) + return (size - n_tt) * 1. / size + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_tt = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_tt += (tf1 and tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + i1 = i1 + 1 + else: + i2 = i2 + 1 + + # We don't need to go through all the longuest + # vector because tf1 or tf2 will be false + # and thus n_tt won't be increased. + + return (size - n_tt) * 1. / size + #------------------------------------------------------------ @@ -1048,8 +2480,12 @@ cdef class SokalMichenerDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = 2 (N_TF + N_FT) / (N + N_TF + N_FT) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: cdef int tf1, tf2, n_neq = 0 cdef cnp.intp_t j for j in range(size): @@ -1058,6 +2494,61 @@ cdef class SokalMichenerDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): n_neq += (tf1 != tf2) return (2.0 * n_neq) / (size + n_neq) + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += tf1 + i1 = i1 + 1 + else: + n_neq += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + n_neq += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + n_neq += tf1 + i1 = i1 + 1 + + return (2.0 * n_neq) / (size + n_neq) #------------------------------------------------------------ # Sokal-Sneath Distance (boolean) @@ -1071,16 +2562,77 @@ cdef class SokalSneathDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): D(x, y) = (N_TF + N_FT) / (N_TT / 2 + N_FT + N_TF) """ - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - cdef int tf1, tf2, ntt = 0, n_neq = 0 + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef int tf1, tf2, n_tt = 0, n_neq = 0 cdef cnp.intp_t j for j in range(size): tf1 = x1[j] != 0 tf2 = x2[j] != 0 n_neq += (tf1 != tf2) - ntt += (tf1 and tf2) - return n_neq / (0.5 * ntt + n_neq) + n_tt += (tf1 and tf2) + return n_neq / (0.5 * n_tt + n_neq) + + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + ITYPE_t tf1, tf2, n_tt = 0, n_neq = 0 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + tf1 = x1_data[i1] != 0 + tf2 = x2_data[i2] != 0 + + if ix1 == ix2: + n_tt += (tf1 and tf2) + n_neq += (tf1 != tf2) + i1 = i1 + 1 + i2 = i2 + 1 + elif ix1 < ix2: + n_neq += tf1 + i1 = i1 + 1 + else: + n_neq += tf2 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + tf2 = x2_data[i2] != 0 + n_neq += tf2 + i2 = i2 + 1 + else: + while i1 < x1_end: + tf1 = x1_data[i1] != 0 + n_neq += tf1 + i1 = i1 + 1 + + return n_neq / (0.5 * n_tt + n_neq) #------------------------------------------------------------ @@ -1104,21 +2656,27 @@ cdef class HaversineDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): raise ValueError("Haversine distance only valid " "in 2 dimensions") - cdef inline DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: - cdef DTYPE_t sin_0 = sin(0.5 * (x1[0] - x2[0])) - cdef DTYPE_t sin_1 = sin(0.5 * (x1[1] - x2[1])) + cdef inline DTYPE_t rdist(self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: + cdef DTYPE_t sin_0 = sin(0.5 * ((x1[0]) - (x2[0]))) + cdef DTYPE_t sin_1 = sin(0.5 * ((x1[1]) - (x2[1]))) return (sin_0 * sin_0 + cos(x1[0]) * cos(x2[0]) * sin_1 * sin_1) - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist(self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return 2 * asin(sqrt(self.rdist(x1, x2, size))) - cdef inline DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1: + cdef inline DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1: return 2 * asin(sqrt(rdist)) - cdef inline DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1: - cdef DTYPE_t tmp = sin(0.5 * dist) + cdef inline DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1: + cdef DTYPE_t tmp = sin(0.5 * dist) return tmp * tmp def rdist_to_dist(self, rdist): @@ -1128,6 +2686,105 @@ cdef class HaversineDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): tmp = np.sin(0.5 * dist) return tmp * tmp + cdef inline DTYPE_t dist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + return 2 * asin(sqrt(self.rdist_csr( + x1_data, + x1_indices, + x2_data, + x2_indices, + x1_start, + x1_end, + x2_start, + x2_end, + size, + ))) + + cdef inline DTYPE_t rdist_csr( + self, + const {{INPUT_DTYPE_t}}[:] x1_data, + const SPARSE_INDEX_TYPE_t[:] x1_indices, + const {{INPUT_DTYPE_t}}[:] x2_data, + const SPARSE_INDEX_TYPE_t[:] x2_indices, + const SPARSE_INDEX_TYPE_t x1_start, + const SPARSE_INDEX_TYPE_t x1_end, + const SPARSE_INDEX_TYPE_t x2_start, + const SPARSE_INDEX_TYPE_t x2_end, + const ITYPE_t size, + ) nogil except -1: + + cdef: + cnp.npy_intp ix1, ix2 + cnp.npy_intp i1 = x1_start + cnp.npy_intp i2 = x2_start + cnp.npy_intp len_x1_indices = x1_indices.shape[0] + cnp.npy_intp len_x2_indices = x2_indices.shape[0] + + DTYPE_t x1_0 = 0 + DTYPE_t x1_1 = 0 + DTYPE_t x2_0 = 0 + DTYPE_t x2_1 = 0 + DTYPE_t sin_0 + DTYPE_t sin_1 + + while i1 < x1_end and i2 < x2_end: + # Use the modulo-trick to implement support for CSR × dense array + # with the CSR × CSR routine. See _pairwise_sparse_dense for more + # details. + ix1 = x1_indices[i1 % len_x1_indices] + ix2 = x2_indices[i2 % len_x2_indices] + + # Find the components in the 2D vectors to work with + x1_component = ix1 if (x1_start == 0) else ix1 % x1_start + x2_component = ix2 if (x2_start == 0) else ix2 % x2_start + + if x1_component == 0: + x1_0 = x1_data[i1] + else: + x1_1 = x1_data[i1] + + if x2_component == 0: + x2_0 = x2_data[i2] + else: + x2_1 = x2_data[i2] + + i1 = i1 + 1 + i2 = i2 + 1 + + if i1 == x1_end: + while i2 < x2_end: + ix2 = x2_indices[i2 % len_x2_indices] + x2_component = ix2 if (x2_start == 0) else ix2 % x2_start + if x2_component == 0: + x2_0 = x2_data[i2] + else: + x2_1 = x2_data[i2] + i2 = i2 + 1 + else: + while i1 < x1_end: + ix1 = x1_indices[i1 % len_x1_indices] + x1_component = ix1 if (x1_start == 0) else ix1 % x1_start + if x1_component == 0: + x1_0 = x1_data[i1] + else: + x1_1 = x1_data[i1] + i1 = i1 + 1 + + sin_0 = sin(0.5 * (x1_0 - x2_0)) + sin_1 = sin(0.5 * (x1_1 - x2_1)) + + return (sin_0 * sin_0 + cos(x1_0) * cos(x2_0) * sin_1 * sin_1) + #------------------------------------------------------------ # User-defined distance # @@ -1150,12 +2807,20 @@ cdef class PyFuncDistance{{name_suffix}}(DistanceMetric{{name_suffix}}): # allowed in cython >= 0.26 since it is a redundant GIL acquisition. The # only way to be back compatible is to inherit `dist` from the base class # without GIL and called an inline `_dist` which acquire GIL. - cdef inline DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) nogil except -1: + cdef inline DTYPE_t dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) nogil except -1: return self._dist(x1, x2, size) - cdef inline DTYPE_t _dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2, - ITYPE_t size) except -1 with gil: + cdef inline DTYPE_t _dist( + self, + const {{INPUT_DTYPE_t}}* x1, + const {{INPUT_DTYPE_t}}* x2, + ITYPE_t size, + ) except -1 with gil: cdef cnp.ndarray x1arr cdef cnp.ndarray x2arr x1arr = _buffer_to_ndarray{{name_suffix}}(x1, size) diff --git a/sklearn/metrics/tests/test_dist_metrics.py b/sklearn/metrics/tests/test_dist_metrics.py index 4cc8b945ffdab..e11be4dab3e20 100644 --- a/sklearn/metrics/tests/test_dist_metrics.py +++ b/sklearn/metrics/tests/test_dist_metrics.py @@ -3,8 +3,6 @@ import copy import numpy as np -from sklearn.utils._testing import assert_allclose - import pytest import scipy.sparse as sp @@ -18,7 +16,7 @@ ) from sklearn.utils import check_random_state -from sklearn.utils._testing import create_memmap_backed_data +from sklearn.utils._testing import assert_allclose, create_memmap_backed_data from sklearn.utils.fixes import sp_version, parse_version @@ -38,8 +36,8 @@ def dist_func(x1, x2, p): [X_mmap, Y_mmap] = create_memmap_backed_data([X64, Y64]) # make boolean arrays: ones and zeros -X_bool = X64.round(0) -Y_bool = Y64.round(0) +X_bool = (X64 < 0.3).astype(np.float64) # quite sparse +Y_bool = (Y64 < 0.7).astype(np.float64) # not too sparse [X_bool_mmap, Y_bool_mmap] = create_memmap_backed_data([X_bool, Y_bool]) @@ -83,14 +81,17 @@ def test_cdist(metric_param_grid, X, Y): ) metric, param_grid = metric_param_grid keys = param_grid.keys() + X_csr, Y_csr = sp.csr_matrix(X), sp.csr_matrix(Y) for vals in itertools.product(*param_grid.values()): kwargs = dict(zip(keys, vals)) - if metric == "mahalanobis": - # See: https://github.com/scipy/scipy/issues/13861 - # Possibly caused by: https://github.com/joblib/joblib/issues/563 - pytest.xfail( - "scipy#13861: cdist with 'mahalanobis' fails on joblib memmap data" - ) + rtol_dict = {} + if metric == "mahalanobis" and X.dtype == np.float32: + # Computation of mahalanobis differs between + # the scipy and scikit-learn implementation. + # Hence, we increase the relative tolerance. + # TODO: Inspect slight numerical discrepancy + # with scipy + rtol_dict = {"rtol": 1e-6} if metric == "wminkowski": # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 @@ -103,8 +104,24 @@ def test_cdist(metric_param_grid, X, Y): D_scipy_cdist = cdist(X, Y, metric, **kwargs) dm = DistanceMetricInterface.get_metric(metric, **kwargs) + + # DistanceMetric.pairwise must be consistent for all + # combinations of formats in {sparse, dense}. D_sklearn = dm.pairwise(X, Y) - assert_allclose(D_sklearn, D_scipy_cdist) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist, **rtol_dict) + + D_sklearn = dm.pairwise(X_csr, Y_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist, **rtol_dict) + + D_sklearn = dm.pairwise(X_csr, Y) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist, **rtol_dict) + + D_sklearn = dm.pairwise(X, Y_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist, **rtol_dict) @pytest.mark.parametrize("metric", BOOL_METRICS) @@ -112,28 +129,56 @@ def test_cdist(metric_param_grid, X, Y): "X_bool, Y_bool", [(X_bool, Y_bool), (X_bool_mmap, Y_bool_mmap)] ) def test_cdist_bool_metric(metric, X_bool, Y_bool): - D_true = cdist(X_bool, Y_bool, metric) + D_scipy_cdist = cdist(X_bool, Y_bool, metric) + dm = DistanceMetric.get_metric(metric) - D12 = dm.pairwise(X_bool, Y_bool) - assert_allclose(D12, D_true) + D_sklearn = dm.pairwise(X_bool, Y_bool) + assert_allclose(D_sklearn, D_scipy_cdist) + + # DistanceMetric.pairwise must be consistent + # on all combinations of format in {sparse, dense}². + X_bool_csr, Y_bool_csr = sp.csr_matrix(X_bool), sp.csr_matrix(Y_bool) + + D_sklearn = dm.pairwise(X_bool, Y_bool) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist) + + D_sklearn = dm.pairwise(X_bool_csr, Y_bool_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist) + + D_sklearn = dm.pairwise(X_bool, Y_bool_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist) + + D_sklearn = dm.pairwise(X_bool_csr, Y_bool) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_cdist) # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") @pytest.mark.parametrize("metric_param_grid", METRICS_DEFAULT_PARAMS) -@pytest.mark.parametrize("X, Y", [(X64, Y64), (X32, Y32), (X_mmap, Y_mmap)]) -def test_pdist(metric_param_grid, X, Y): +@pytest.mark.parametrize("X", [X64, X32, X_mmap]) +def test_pdist(metric_param_grid, X): DistanceMetricInterface = ( - DistanceMetric if X.dtype == Y.dtype == np.float64 else DistanceMetric32 + DistanceMetric if X.dtype == np.float64 else DistanceMetric32 ) metric, param_grid = metric_param_grid keys = param_grid.keys() + X_csr = sp.csr_matrix(X) for vals in itertools.product(*param_grid.values()): kwargs = dict(zip(keys, vals)) - if metric == "mahalanobis": - # See: https://github.com/scipy/scipy/issues/13861 - pytest.xfail("scipy#13861: pdist with 'mahalanobis' fails onmemmap data") - elif metric == "wminkowski": + rtol_dict = {} + if metric == "mahalanobis" and X.dtype == np.float32: + # Computation of mahalanobis differs between + # the scipy and scikit-learn implementation. + # Hence, we increase the relative tolerance. + # TODO: Inspect slight numerical discrepancy + # with scipy + rtol_dict = {"rtol": 1e-6} + + if metric == "wminkowski": if sp_version >= parse_version("1.8.0"): pytest.skip("wminkowski will be removed in SciPy 1.8.0") @@ -142,23 +187,37 @@ def test_pdist(metric_param_grid, X, Y): if sp_version >= parse_version("1.6.0"): ExceptionToAssert = DeprecationWarning with pytest.warns(ExceptionToAssert): - D_true = cdist(X, X, metric, **kwargs) + D_scipy_pdist = cdist(X, X, metric, **kwargs) else: - D_true = cdist(X, X, metric, **kwargs) + D_scipy_pdist = cdist(X, X, metric, **kwargs) dm = DistanceMetricInterface.get_metric(metric, **kwargs) - D12 = dm.pairwise(X) - assert_allclose(D12, D_true) + D_sklearn = dm.pairwise(X) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_scipy_pdist, **rtol_dict) + + D_sklearn_csr = dm.pairwise(X_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn_csr, D_scipy_pdist, **rtol_dict) + + D_sklearn_csr = dm.pairwise(X_csr, X_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn_csr, D_scipy_pdist, **rtol_dict) # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") @pytest.mark.parametrize("metric_param_grid", METRICS_DEFAULT_PARAMS) def test_distance_metrics_dtype_consistency(metric_param_grid): - # DistanceMetric must return similar distances for - # both 64bit and 32bit data. + # DistanceMetric must return similar distances for both float32 and float64 + # input data. metric, param_grid = metric_param_grid keys = param_grid.keys() + + # Choose rtol to make sure that this test is robust to changes in the random + # seed in the module-level test data generation code. + rtol = 1e-5 + for vals in itertools.product(*param_grid.values()): kwargs = dict(zip(keys, vals)) dm64 = DistanceMetric.get_metric(metric, **kwargs) @@ -166,25 +225,34 @@ def test_distance_metrics_dtype_consistency(metric_param_grid): D64 = dm64.pairwise(X64) D32 = dm32.pairwise(X32) - assert_allclose(D64, D32) + + # Both results are np.float64 dtype because the accumulation accross + # features is done in float64. However the input data and the element + # wise arithmetic operations are done in float32 so we can expect a + # small discrepancy. + assert D64.dtype == D32.dtype == np.float64 + + # assert_allclose introspects the dtype of the input arrays to decide + # which rtol value to use by default but in this case we know that D32 + # is not computed with the same precision so we set rtol manually. + assert_allclose(D64, D32, rtol=rtol) D64 = dm64.pairwise(X64, Y64) D32 = dm32.pairwise(X32, Y32) - assert_allclose(D64, D32) + assert_allclose(D64, D32, rtol=rtol) @pytest.mark.parametrize("metric", BOOL_METRICS) @pytest.mark.parametrize("X_bool", [X_bool, X_bool_mmap]) def test_pdist_bool_metrics(metric, X_bool): - D_true = cdist(X_bool, X_bool, metric) + D_scipy_pdist = cdist(X_bool, X_bool, metric) dm = DistanceMetric.get_metric(metric) - D12 = dm.pairwise(X_bool) - # Based on https://github.com/scipy/scipy/pull/7373 - # When comparing two all-zero vectors, scipy>=1.2.0 jaccard metric - # was changed to return 0, instead of nan. - if metric == "jaccard" and sp_version < parse_version("1.2.0"): - D_true[np.isnan(D_true)] = 0 - assert_allclose(D12, D_true) + D_sklearn = dm.pairwise(X_bool) + assert_allclose(D_sklearn, D_scipy_pdist) + + X_bool_csr = sp.csr_matrix(X_bool) + D_sklearn = dm.pairwise(X_bool_csr) + assert_allclose(D_sklearn, D_scipy_pdist) # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @@ -224,7 +292,20 @@ def test_pickle_bool_metrics(metric, X_bool): assert_allclose(D1, D2) -def test_haversine_metric(): +@pytest.mark.parametrize("X, Y", [(X64, Y64), (X32, Y32), (X_mmap, Y_mmap)]) +def test_haversine_metric(X, Y): + DistanceMetricInterface = ( + DistanceMetric if X.dtype == np.float64 else DistanceMetric32 + ) + + # The Haversine DistanceMetric only works on 2 features. + X = np.asarray(X[:, :2]) + Y = np.asarray(Y[:, :2]) + + X_csr, Y_csr = sp.csr_matrix(X), sp.csr_matrix(Y) + + # Haversine is not supported by scipy.special.distance.{cdist,pdist} + # So we reimplement it to have a reference. def haversine_slow(x1, x2): return 2 * np.arcsin( np.sqrt( @@ -233,18 +314,31 @@ def haversine_slow(x1, x2): ) ) - X = np.random.random((10, 2)) + D_reference = np.zeros((X_csr.shape[0], Y_csr.shape[0])) + for i, xi in enumerate(X): + for j, yj in enumerate(Y): + D_reference[i, j] = haversine_slow(xi, yj) - haversine = DistanceMetric.get_metric("haversine") + haversine = DistanceMetricInterface.get_metric("haversine") - D1 = haversine.pairwise(X) - D2 = np.zeros_like(D1) - for i, x1 in enumerate(X): - for j, x2 in enumerate(X): - D2[i, j] = haversine_slow(x1, x2) + D_sklearn = haversine.pairwise(X, Y) + assert_allclose( + haversine.dist_to_rdist(D_sklearn), np.sin(0.5 * D_reference) ** 2, rtol=1e-6 + ) - assert_allclose(D1, D2) - assert_allclose(haversine.dist_to_rdist(D1), np.sin(0.5 * D2) ** 2) + assert_allclose(D_sklearn, D_reference) + + D_sklearn = haversine.pairwise(X_csr, Y_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_reference) + + D_sklearn = haversine.pairwise(X_csr, Y) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_reference) + + D_sklearn = haversine.pairwise(X, Y_csr) + assert D_sklearn.flags.c_contiguous + assert_allclose(D_sklearn, D_reference) def test_pyfunc_metric(): diff --git a/sklearn/utils/_typedefs.pxd b/sklearn/utils/_typedefs.pxd index ee0c8ca3b57e9..a6e390705496b 100644 --- a/sklearn/utils/_typedefs.pxd +++ b/sklearn/utils/_typedefs.pxd @@ -15,3 +15,14 @@ cdef enum: ctypedef cnp.intp_t ITYPE_t # WARNING: should match ITYPE in typedefs.pyx ctypedef cnp.int32_t INT32TYPE_t # WARNING: should match INT32TYPE in typedefs.pyx ctypedef cnp.int64_t INT64TYPE_t # WARNING: should match INT32TYPE in typedefs.pyx + +# scipy matrices indices dtype (namely for indptr and indices arrays) +# +# Note that indices might need to be represented as cnp.int64_t. +# Currently, we use Cython classes which do not handle fused types +# so we hardcode this type to cnp.int32_t, supporting all but edge +# cases. +# +# TODO: support cnp.int64_t for this case +# See: https://github.com/scikit-learn/scikit-learn/issues/23653 +ctypedef cnp.int32_t SPARSE_INDEX_TYPE_t diff --git a/sklearn/utils/_typedefs.pyx b/sklearn/utils/_typedefs.pyx index 09e5a6a44944a..49d0e46101b4f 100644 --- a/sklearn/utils/_typedefs.pyx +++ b/sklearn/utils/_typedefs.pyx @@ -19,6 +19,9 @@ INT64TYPE = np.int64 # WARNING: this should match INT64TYPE_t in typedefs.pxd #DTYPE = np.asarray(ddummy_view).dtype DTYPE = np.float64 # WARNING: this should match DTYPE_t in typedefs.pxd +# WARNING: this must match SPARSE_INDEX_TYPE_t in typedefs.pxd +SPARSE_INDEX_TYPE = np.int32 + # some handy constants cdef DTYPE_t INF = np.inf cdef DTYPE_t PI = np.pi