Thanks to visit codestin.com
Credit goes to github.com

Skip to content

MAINT Introduce FastEuclideanPairwiseArgKmin #22065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

302 changes: 274 additions & 28 deletions sklearn/metrics/_pairwise_distances_reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,24 @@
cimport numpy as np
import numpy as np
import warnings
import scipy.sparse

from .. import get_config
from libc.stdlib cimport free, malloc
from libc.float cimport DBL_MAX
from cython cimport final
from cython.parallel cimport parallel, prange

from ._dist_metrics cimport DatasetsPair
from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair
from ..utils._cython_blas cimport (
BLAS_Order,
BLAS_Trans,
ColMajor,
NoTrans,
RowMajor,
Trans,
_dot,
_gemm,
)
from ..utils._heap cimport simultaneous_sort, heap_push
from ..utils._openmp_helpers cimport _openmp_thread_num
from ..utils._typedefs cimport ITYPE_t, DTYPE_t
Expand All @@ -33,19 +42,45 @@ from numbers import Integral
from typing import List
from scipy.sparse import issparse
from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING
from ..utils import check_scalar
from ..utils import check_scalar, _in_unstable_openblas_configuration
from ..utils.fixes import threadpool_limits
from ..utils._openmp_helpers import _openmp_effective_n_threads
from ..utils._typedefs import ITYPE, DTYPE


np.import_array()

cpdef DTYPE_t[::1] _sqeuclidean_row_norms(
    const DTYPE_t[:, ::1] X,
    ITYPE_t num_threads,
):
    """Compute the squared euclidean norm of the rows of X in parallel.

    This is faster than using np.einsum("ij, ij->i") even when using a single thread.

    Parameters
    ----------
    X : const memoryview of shape (n, d), C-contiguous, dtype DTYPE
        Row vectors whose squared norms are computed.
    num_threads : ITYPE_t
        Number of OpenMP threads used by the prange loop.

    Returns
    -------
    squared_row_norms : memoryview of shape (n,), dtype DTYPE
        squared_row_norms[i] == dot(X[i], X[i]).
    """
    cdef:
        # Casting for X to remove the const qualifier is needed because APIs
        # exposed via scipy.linalg.cython_blas aren't reflecting the arguments'
        # const qualifier.
        # See: https://github.com/scipy/scipy/issues/14262
        DTYPE_t * X_ptr = <DTYPE_t *> &X[0, 0]
        ITYPE_t idx = 0
        ITYPE_t n = X.shape[0]
        ITYPE_t d = X.shape[1]
        # One entry per row of X; filled in parallel below.
        DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE)

    # ||X_i||² computed as the BLAS level-1 dot product of row i with itself,
    # outside the GIL; 'static' schedule is appropriate since every row costs
    # the same amount of work.
    for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads):
        squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1)

    return squared_row_norms



cdef class PairwiseDistancesReduction:
"""Abstract base class for pairwise distance computation & reduction.

Subclasses of this class compute pairwise distances between a set of
row vectors of X and another set of row vectors pf Y and apply a reduction on top.
row vectors of X and another set of row vectors of Y and apply a reduction on top.
The reduction takes a matrix of pairwise distances between rows of X and Y
as input and outputs an aggregate data-structure for each row of X.
The aggregate values are typically smaller than the number of rows in Y,
Expand Down Expand Up @@ -180,14 +215,12 @@ cdef class PairwiseDistancesReduction:
not issparse(Y) and Y.dtype == np.float64 and
metric in cls.valid_metrics())

def __cinit__(
def __init__(
self,
DatasetsPair datasets_pair,
chunk_size=None,
n_threads=None,
strategy=None,
*args,
**kwargs,
):
cdef:
ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks
Expand Down Expand Up @@ -611,13 +644,32 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
# For future work, this might can be an entrypoint to specialise operations
# for various back-end and/or hardware and/or datatypes, and/or fused
# {sparse, dense}-datasetspair etc.

pda = PairwiseDistancesArgKmin(
datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
k=k,
chunk_size=chunk_size,
strategy=strategy,
)
if (
metric in ("euclidean", "sqeuclidean")
and not issparse(X)
and not issparse(Y)
):
# Specialized implementation with improved arithmetic intensity
# and vector instructions (SIMD) by processing several vectors
# at a time to leverage a call to the BLAS GEMM routine as explained
# in more details in the docstring.
use_squared_distances = metric == "sqeuclidean"
pda = FastEuclideanPairwiseDistancesArgKmin(
X=X, Y=Y, k=k,
use_squared_distances=use_squared_distances,
chunk_size=chunk_size,
strategy=strategy,
metric_kwargs=metric_kwargs,
)
else:
# Fall back on a generic implementation that handles most scipy
# metrics by computing the distances between 2 vectors at a time.
pda = PairwiseDistancesArgKmin(
datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs),
k=k,
chunk_size=chunk_size,
strategy=strategy,
)

# Limit the number of threads in second level of nested parallelism for BLAS
# to avoid threads over-subscription (in GEMM for instance).
Expand All @@ -629,15 +681,22 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):

return pda._finalize_results(return_distance)

def __cinit__(
def __init__(
self,
DatasetsPair datasets_pair,
chunk_size=None,
n_threads=None,
strategy=None,
*args,
**kwargs,
):
ITYPE_t k=1,
):
super().__init__(
datasets_pair=datasets_pair,
chunk_size=chunk_size,
n_threads=n_threads,
strategy=strategy,
)
self.k = check_scalar(k, "k", Integral, min_val=1)

# Allocating pointers to datastructures but not the datastructures themselves.
# There are as many pointers as effective threads.
#
Expand All @@ -654,16 +713,6 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
sizeof(ITYPE_t *) * self.chunks_n_threads
)

def __init__(
self,
DatasetsPair datasets_pair,
chunk_size=None,
n_threads=None,
strategy=None,
ITYPE_t k=1,
):
self.k = check_scalar(k, "k", Integral, min_val=1)

# Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`.
self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE)
self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE)
Expand Down Expand Up @@ -837,3 +886,200 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances)

return np.asarray(self.argkmin_indices)


cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
    """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance.

    The full pairwise squared distances matrix is computed as follows:

                  ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||²

    The middle term gets computed efficiently below using BLAS Level 3 GEMM.

    Notes
    -----
    This implementation has a superior arithmetic intensity and hence
    better running time when the alternative is IO bound, but it can suffer
    from numerical instability caused by catastrophic cancellation potentially
    introduced by the subtraction in the arithmetic expression above.
    """

    cdef:
        # Dense, C-contiguous views on the two datasets (borrowed from the
        # DenseDenseDatasetsPair constructed in __init__).
        const DTYPE_t[:, ::1] X
        const DTYPE_t[:, ::1] Y
        # Precomputed ||X_i||² and ||Y_j||² terms of the expansion above.
        const DTYPE_t[::1] X_norm_squared
        const DTYPE_t[::1] Y_norm_squared

        # Buffers for GEMM: one chunk-sized scratch buffer per thread,
        # holding the `-2 * X_c @ Y_c.T` middle terms.
        DTYPE_t ** dist_middle_terms_chunks
        bint use_squared_distances

    @classmethod
    def is_usable_for(cls, X, Y, metric) -> bool:
        # On top of the generic requirements, refuse OpenBLAS configurations
        # known to give wrong results or deadlock with nested parallelism.
        return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and
                not _in_unstable_openblas_configuration())

    def __init__(
        self,
        X,
        Y,
        ITYPE_t k,
        bint use_squared_distances=False,
        chunk_size=None,
        n_threads=None,
        strategy=None,
        metric_kwargs=None,
    ):
        # "Y_norm_squared" is consumed below, so it must not trigger the
        # "will be ignored" warning; only genuinely unused kwargs do.
        if (
            metric_kwargs is not None and
            len(metric_kwargs.keys() - {"Y_norm_squared"}) > 0
        ):
            warnings.warn(
                # Note the trailing space before "usable": adjacent string
                # literals are concatenated without any separator.
                f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
                f"usable for this case ({self.__class__.__name__}) and will be ignored.",
                UserWarning,
                stacklevel=3,
            )

        super().__init__(
            # The datasets pair here is used for exact distances computations
            datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"),
            chunk_size=chunk_size,
            n_threads=n_threads,
            strategy=strategy,
            k=k,
        )
        # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair
        cdef:
            DenseDenseDatasetsPair datasets_pair = <DenseDenseDatasetsPair> self.datasets_pair
        self.X, self.Y = datasets_pair.X, datasets_pair.Y

        if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
            # Reuse caller-provided squared norms (e.g. from pairwise_distances).
            self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared")
        else:
            self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads)

        # Do not recompute norms if datasets are identical.
        self.X_norm_squared = (
            self.Y_norm_squared if X is Y else
            _sqeuclidean_row_norms(self.X, self.effective_n_threads)
        )
        self.use_squared_distances = use_squared_distances

        # Temporary datastructures used in threads: only the array of
        # per-thread pointers is allocated here; the per-thread buffers are
        # malloc'ed/freed in the parallel init/finalize callbacks below.
        self.dist_middle_terms_chunks = <DTYPE_t **> malloc(
            sizeof(DTYPE_t *) * self.chunks_n_threads
        )

    def __dealloc__(self):
        if self.dist_middle_terms_chunks is not NULL:
            free(self.dist_middle_terms_chunks)

    @final
    cdef void compute_exact_distances(self) nogil:
        # For "sqeuclidean" the rank-preserving surrogate distances already
        # are the requested metric; only take the square roots (done by the
        # parent implementation) for the "euclidean" case.
        if not self.use_squared_distances:
            PairwiseDistancesArgKmin.compute_exact_distances(self)

    @final
    cdef void _parallel_on_X_parallel_init(
        self,
        ITYPE_t thread_num,
    ) nogil:
        PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num)

        # Temporary buffer for the `-2 * X_c @ Y_c.T` term
        self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
            self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
        )

    @final
    cdef void _parallel_on_X_parallel_finalize(
        self,
        ITYPE_t thread_num
    ) nogil:
        PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num)
        free(self.dist_middle_terms_chunks[thread_num])

    @final
    cdef void _parallel_on_Y_init(
        self,
    ) nogil:
        cdef ITYPE_t thread_num
        PairwiseDistancesArgKmin._parallel_on_Y_init(self)

        for thread_num in range(self.chunks_n_threads):
            # Temporary buffer for the `-2 * X_c @ Y_c.T` term
            self.dist_middle_terms_chunks[thread_num] = <DTYPE_t *> malloc(
                self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t)
            )

    @final
    cdef void _parallel_on_Y_finalize(
        self,
    ) nogil:
        cdef ITYPE_t thread_num
        PairwiseDistancesArgKmin._parallel_on_Y_finalize(self)

        for thread_num in range(self.chunks_n_threads):
            free(self.dist_middle_terms_chunks[thread_num])

    @final
    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        ITYPE_t X_start,
        ITYPE_t X_end,
        ITYPE_t Y_start,
        ITYPE_t Y_end,
        ITYPE_t thread_num,
    ) nogil:
        cdef:
            ITYPE_t i, j

            const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :]
            const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
            DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num]
            DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num]

            # Careful: LDA, LDB and LDC are given for F-ordered arrays
            # in BLAS documentations, for instance:
            # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa
            #
            # Here, we use their counterpart values to work with C-ordered arrays.
            BLAS_Order order = RowMajor
            BLAS_Trans ta = NoTrans
            BLAS_Trans tb = Trans
            ITYPE_t m = X_c.shape[0]
            ITYPE_t n = Y_c.shape[0]
            ITYPE_t K = X_c.shape[1]
            DTYPE_t alpha = - 2.
            # Casting for A and B to remove the const is needed because APIs exposed via
            # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
            # See: https://github.com/scipy/scipy/issues/14262
            DTYPE_t * A = <DTYPE_t*> & X_c[0, 0]
            ITYPE_t lda = X_c.shape[1]
            DTYPE_t * B = <DTYPE_t*> & Y_c[0, 0]
            # X_c and Y_c share the feature dimension, so LDB == LDA here.
            ITYPE_t ldb = X_c.shape[1]
            DTYPE_t beta = 0.
            ITYPE_t ldc = Y_c.shape[0]

        # dist_middle_terms = `-2 * X_c @ Y_c.T`
        _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)

        # Pushing the distance and their associated indices on heaps
        # which keep tracks of the argkmin.
        for i in range(X_c.shape[0]):
            for j in range(Y_c.shape[0]):
                heap_push(
                    heaps_r_distances + i * self.k,
                    heaps_indices + i * self.k,
                    self.k,
                    # Using the squared euclidean distance as the rank-preserving distance:
                    #
                    #             ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
                    #
                    (
                        self.X_norm_squared[i + X_start] +
                        dist_middle_terms[i * Y_c.shape[0] + j] +
                        self.Y_norm_squared[j + Y_start]
                    ),
                    j + Y_start,
                )
Loading