From 0532aab213ee3fc8a5e88af1f5bde5454035c6dd Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 5 Jan 2022 15:19:56 +0100 Subject: [PATCH 01/30] MAINT Introduce Pairwise Distances Reductions private submodule (#22064) Co-authored-by: Thomas J. Fan Co-authored-by: Christian Lorentzen Co-authored-by: Olivier Grisel --- sklearn/_config.py | 30 +- sklearn/metrics/_dist_metrics.pxd | 21 + sklearn/metrics/_dist_metrics.pyx | 194 +++- .../metrics/_pairwise_distances_reduction.pyx | 839 ++++++++++++++++++ sklearn/metrics/setup.py | 7 + .../test_pairwise_distances_reduction.py | 357 ++++++++ sklearn/tests/test_config.py | 3 + sklearn/utils/_openmp_helpers.pxd | 6 + sklearn/utils/_openmp_helpers.pyx | 15 +- 9 files changed, 1458 insertions(+), 14 deletions(-) create mode 100644 sklearn/metrics/_pairwise_distances_reduction.pyx create mode 100644 sklearn/metrics/tests/test_pairwise_distances_reduction.py create mode 100644 sklearn/utils/_openmp_helpers.pxd diff --git a/sklearn/_config.py b/sklearn/_config.py index c41c180012056..d6a02737f640d 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -9,6 +9,9 @@ "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": int( + os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256) + ), } _threadlocal = threading.local() @@ -40,7 +43,11 @@ def get_config(): def set_config( - assume_finite=None, working_memory=None, print_changed_only=None, display=None + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + pairwise_dist_chunk_size=None, ): """Set global scikit-learn configuration @@ -80,6 +87,12 @@ def set_config( .. versionadded:: 0.23 + pairwise_dist_chunk_size : int, default=None + The number of vectors per chunk for PairwiseDistancesReduction. + Default is 256 (suitable for most of modern laptops' caches and architectures). + + .. versionadded:: 1.1 + See Also -------- config_context : Context manager for global scikit-learn configuration. @@ -95,11 +108,18 @@ def set_config( local_config["print_changed_only"] = print_changed_only if display is not None: local_config["display"] = display + if pairwise_dist_chunk_size is not None: + local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size @contextmanager def config_context( - *, assume_finite=None, working_memory=None, print_changed_only=None, display=None + *, + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + pairwise_dist_chunk_size=None, ): """Context manager for global scikit-learn configuration. @@ -138,6 +158,12 @@ def config_context( .. versionadded:: 0.23 + pairwise_dist_chunk_size : int, default=None + The number of vectors per chunk for PairwiseDistancesReduction. + Default is 256 (suitable for most of modern laptops' caches and architectures). + + .. versionadded:: 1.1 + Yields ------ None. 
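For reference, a minimal sketch (not part of the diff above or below) of how the new `pairwise_dist_chunk_size` option is meant to be used through the existing public configuration helpers. It relies only on `sklearn.get_config`, `sklearn.set_config` and `sklearn.config_context` as extended by this patch, and assumes the `SKLEARN_PAIRWISE_DIST_CHUNK_SIZE` environment variable is unset:

    import sklearn

    # With SKLEARN_PAIRWISE_DIST_CHUNK_SIZE unset, the default chunk size is 256.
    assert sklearn.get_config()["pairwise_dist_chunk_size"] == 256

    # Change it for subsequent calls in the current thread
    # (the configuration is stored in a thread-local, see _threadlocal above) ...
    sklearn.set_config(pairwise_dist_chunk_size=512)

    # ... or temporarily, with the previous value restored on exit.
    with sklearn.config_context(pairwise_dist_chunk_size=128):
        assert sklearn.get_config()["pairwise_dist_chunk_size"] == 128

    assert sklearn.get_config()["pairwise_dist_chunk_size"] == 512
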
diff --git a/sklearn/metrics/_dist_metrics.pxd b/sklearn/metrics/_dist_metrics.pxd index 611f6759e2c8b..e7c2f2ea2f926 100644 --- a/sklearn/metrics/_dist_metrics.pxd +++ b/sklearn/metrics/_dist_metrics.pxd @@ -64,3 +64,24 @@ cdef class DistanceMetric: cdef DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1 cdef DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1 + + +###################################################################### +# DatasetsPair base class +cdef class DatasetsPair: + cdef DistanceMetric distance_metric + + cdef ITYPE_t n_samples_X(self) nogil + + cdef ITYPE_t n_samples_Y(self) nogil + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil + + +cdef class DenseDenseDatasetsPair(DatasetsPair): + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + ITYPE_t d diff --git a/sklearn/metrics/_dist_metrics.pyx b/sklearn/metrics/_dist_metrics.pyx index f7d22c1badfa2..6cf93baeca925 100644 --- a/sklearn/metrics/_dist_metrics.pyx +++ b/sklearn/metrics/_dist_metrics.pyx @@ -4,6 +4,8 @@ import numpy as np cimport numpy as np +from cython cimport final + np.import_array() # required in order to use C-API @@ -23,10 +25,10 @@ cdef inline np.ndarray _buffer_to_ndarray(const DTYPE_t* x, np.npy_intp n): return PyArray_SimpleNewFromData(1, &n, DTYPECODE, x) -# some handy constants from libc.math cimport fabs, sqrt, exp, pow, cos, sin, asin cdef DTYPE_t INF = np.inf +from scipy.sparse import csr_matrix, issparse from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE from ..utils._typedefs import DTYPE, ITYPE from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper @@ -67,6 +69,16 @@ METRIC_MAPPING = {'euclidean': EuclideanDistance, 'haversine': HaversineDistance, 'pyfunc': PyFuncDistance} +BOOL_METRICS = [ + "matching", + "jaccard", + "dice", + "kulsinski", + "rogerstanimoto", + "russellrao", + "sokalmichener", + "sokalsneath", +] def get_valid_metric_ids(L): """Given an iterable of metric class names or class identifiers, @@ -195,8 +207,8 @@ cdef class DistanceMetric: """ def __cinit__(self): self.p = 2 - self.vec = np.zeros(1, dtype=DTYPE, order='c') - self.mat = np.zeros((1, 1), dtype=DTYPE, order='c') + self.vec = np.zeros(1, dtype=DTYPE, order='C') + self.mat = np.zeros((1, 1), dtype=DTYPE, order='C') self.size = 1 def __reduce__(self): @@ -306,8 +318,9 @@ cdef class DistanceMetric: This can optionally be overridden in a base class. The rank-preserving surrogate distance is any measure that yields the same - rank as the distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. + rank as the distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. """ return self.dist(x1, x2, size) @@ -343,8 +356,9 @@ cdef class DistanceMetric: """Convert the rank-preserving surrogate distance to the distance. The surrogate distance is any measure that yields the same rank as the - distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. + distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. 
Parameters ---------- @@ -362,8 +376,9 @@ cdef class DistanceMetric: """Convert the true distance to the rank-preserving surrogate distance. The surrogate distance is any measure that yields the same rank as the - distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. + distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. Parameters ---------- @@ -1150,3 +1165,164 @@ cdef class PyFuncDistance(DistanceMetric): cdef inline double fmax(double a, double b) nogil: return max(a, b) + + +###################################################################### +# Datasets Pair Classes +cdef class DatasetsPair: + """Abstract class which wraps a pair of datasets (X, Y). + + This class allows computing distances between a single pair of rows of + of X and Y at a time given the pair of their indices (i, j). This class is + specialized for each metric thanks to the :func:`get_for` factory classmethod. + + The handling of parallelization over chunks to compute the distances + and aggregation for several rows at a time is done in dedicated + subclasses of PairwiseDistancesReduction that in-turn rely on + subclasses of DatasetsPair for each pair of rows in the data. The goal + is to make it possible to decouple the generic parallelization and + aggregation logic from metric-specific computation as much as + possible. + + X and Y can be stored as C-contiguous np.ndarrays or CSR matrices + in subclasses. + + This class avoids the overhead of dispatching distance computations + to :class:`sklearn.metrics.DistanceMetric` based on the physical + representation of the vectors (sparse vs. dense). It makes use of + cython.final to remove the overhead of dispatching method calls. + + Parameters + ---------- + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two vectors of (X, Y). + """ + + @classmethod + def get_for( + cls, + X, + Y, + str metric="euclidean", + dict metric_kwargs=None, + ) -> DatasetsPair: + """Return the DatasetsPair implementation for the given arguments. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + metric : str, default='euclidean' + The distance metric to compute between rows of X and Y. + The default metric is a fast implementation of the Euclidean + metric. For a list of available metrics, see the documentation + of :class:`~sklearn.metrics.DistanceMetric`. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + Returns + ------- + datasets_pair: DatasetsPair + The suited DatasetsPair implementation. + """ + cdef: + DistanceMetric distance_metric = DistanceMetric.get_metric( + metric, + **(metric_kwargs or {}) + ) + + if not(X.dtype == Y.dtype == np.float64): + raise ValueError( + f"Only 64bit float datasets are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}" + ) + + # Metric-specific checks that do not replace nor duplicate `check_array`. 
+ distance_metric._validate_data(X) + distance_metric._validate_data(Y) + + # TODO: dispatch to other dataset pairs for sparse support once available: + if issparse(X) or issparse(Y): + raise ValueError("Only dense datasets are supported for X and Y.") + + return DenseDenseDatasetsPair(X, Y, distance_metric) + + def __init__(self, DistanceMetric distance_metric): + self.distance_metric = distance_metric + + cdef ITYPE_t n_samples_X(self) nogil: + """Number of samples in X.""" + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -999 + + cdef ITYPE_t n_samples_Y(self) nogil: + """Number of samples in Y.""" + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -999 + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.dist(i, j) + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + # This is a abstract method. + # This _must_ always be overwritten in subclasses. + # TODO: add "with gil: raise" here when supporting Cython 3.0 + return -1 + +@final +cdef class DenseDenseDatasetsPair(DatasetsPair): + """Compute distances between row vectors of two arrays. + + Parameters + ---------- + X: ndarray of shape (n_samples_X, n_features) + Rows represent vectors. Must be C-contiguous. + + Y: ndarray of shape (n_samples_Y, n_features) + Rows represent vectors. Must be C-contiguous. + + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two row vectors of (X, Y). + """ + + def __init__(self, X, Y, DistanceMetric distance_metric): + super().__init__(distance_metric) + # Arrays have already been checked + self.X = X + self.Y = Y + self.d = X.shape[1] + + @final + cdef ITYPE_t n_samples_X(self) nogil: + return self.X.shape[0] + + @final + cdef ITYPE_t n_samples_Y(self) nogil: + return self.Y.shape[0] + + @final + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.rdist(&self.X[i, 0], + &self.Y[j, 0], + self.d) + + @final + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.dist(&self.X[i, 0], + &self.Y[j, 0], + self.d) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx new file mode 100644 index 0000000000000..830df08e1a952 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -0,0 +1,839 @@ +# Pairwise Distances Reductions +# ============================= +# +# Author: Julien Jerphanion +# +# +# The abstractions defined here are used in various algorithms performing +# the same structure of operations on distances between row vectors +# of a datasets pair (X, Y). +# +# Importantly, the core of the computation is chunked to make sure that the pairwise +# distance chunk matrices stay in CPU cache before applying the final reduction step. +# Furthermore, the chunking strategy is also used to leverage OpenMP-based parallelism +# (using Cython prange loops) which gives another multiplicative speed-up in +# favorable cases on many-core machines. +cimport numpy as np +import numpy as np +import warnings +import scipy.sparse + +from .. 
import get_config +from libc.stdlib cimport free, malloc +from libc.float cimport DBL_MAX +from cython cimport final +from cython.parallel cimport parallel, prange + +from ._dist_metrics cimport DatasetsPair +from ..utils._heap cimport simultaneous_sort, heap_push +from ..utils._openmp_helpers cimport _openmp_thread_num +from ..utils._typedefs cimport ITYPE_t, DTYPE_t + +from numbers import Integral +from typing import List +from scipy.sparse import issparse +from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING +from ..utils import check_scalar +from ..utils.fixes import threadpool_limits +from ..utils._openmp_helpers import _openmp_effective_n_threads +from ..utils._typedefs import ITYPE, DTYPE + + +np.import_array() + +cdef class PairwiseDistancesReduction: + """Abstract base class for pairwise distance computation & reduction. + + Subclasses of this class compute pairwise distances between a set of + row vectors of X and another set of row vectors pf Y and apply a reduction on top. + The reduction takes a matrix of pairwise distances between rows of X and Y + as input and outputs an aggregate data-structure for each row of X. + The aggregate values are typically smaller than the number of rows in Y, + hence the term reduction. + + For computational reasons, it is interesting to perform the reduction on + the fly on chunks of rows of X and Y so as to keep intermediate + data-structures in CPU cache and avoid unnecessary round trips of large + distance arrays with the RAM that would otherwise severely degrade the + speed by making the overall processing memory-bound. + + The base class provides the generic chunked parallelization template using + OpenMP loops (Cython prange), either on rows of X or rows of Y depending on + their respective sizes. + + The subclasses are specialized for reduction. + + The actual distance computation for a given pair of rows of X and Y are + delegated to format-specific subclasses of the DatasetsPair companion base + class. + + Parameters + ---------- + datasets_pair: DatasetsPair + The pair of dataset to use. + + chunk_size: int, default=None + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + n_threads: int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on :method:`~PairwiseDistancesReduction.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. 
+ + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + """ + + cdef: + readonly DatasetsPair datasets_pair + + # The number of threads that can be used is stored in effective_n_threads. + # + # The number of threads to use in the parallelisation strategy + # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: + # for small datasets, less threads might be needed to loop over pair of chunks. + # + # Hence the number of threads that _will_ be used for looping over chunks + # is stored in chunks_n_threads, allowing solely using what we need. + # + # Thus, an invariant is: + # + # chunks_n_threads <= effective_n_threads + # + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + + ITYPE_t n_samples_chunk, chunk_size + + ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk + ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk + + bint execute_in_parallel_on_Y + + @classmethod + def valid_metrics(cls) -> List[str]: + excluded = { + "pyfunc", # is relatively slow because we need to coerce data as np arrays + "mahalanobis", # is numerically unstable + # TODO: In order to support discrete distance metrics, we need to have a + # stable simultaneous sort which preserves the order of the input. + # The best might be using std::stable_sort and a Comparator taking an + # Arrays of Structures instead of Structure of Arrays (currently used). + "hamming", + *BOOL_METRICS, + } + return sorted(set(METRIC_MAPPING.keys()) - excluded) + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + """Return True if the PairwiseDistancesReduction can be used for the given parameters. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + Returns + ------- + True if the PairwiseDistancesReduction can be used, else False. 
+ """ + # TODO: support sparse arrays and 32 bits + return (not issparse(X) and X.dtype == np.float64 and + not issparse(Y) and Y.dtype == np.float64 and + metric in cls.valid_metrics()) + + def __cinit__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + n_threads=None, + strategy=None, + *args, + **kwargs, + ): + cdef: + ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks + + if chunk_size is None: + chunk_size = get_config().get("pairwise_dist_chunk_size", 256) + + self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) + + self.effective_n_threads = _openmp_effective_n_threads(n_threads) + + self.datasets_pair = datasets_pair + + self.n_samples_X = datasets_pair.n_samples_X() + self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) + X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk + X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk + self.X_n_chunks = X_n_full_chunks + (X_n_samples_remainder != 0) + + if X_n_samples_remainder != 0: + self.X_n_samples_last_chunk = X_n_samples_remainder + else: + self.X_n_samples_last_chunk = self.X_n_samples_chunk + + self.n_samples_Y = datasets_pair.n_samples_Y() + self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) + Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk + Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk + self.Y_n_chunks = Y_n_full_chunks + (Y_n_samples_remainder != 0) + + if Y_n_samples_remainder != 0: + self.Y_n_samples_last_chunk = Y_n_samples_remainder + else: + self.Y_n_samples_last_chunk = self.Y_n_samples_chunk + + if strategy is None: + strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') + + if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): + raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " + f"or 'auto', but currently strategy='{self.strategy}'.") + + if strategy == 'auto': + # This is a simple heuristic whose constant for the + # comparison has been chosen based on experiments. + if 4 * self.chunk_size * self.effective_n_threads < self.n_samples_X: + strategy = 'parallel_on_X' + else: + strategy = 'parallel_on_Y' + + self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" + + # Not using less, not using more. + self.chunks_n_threads = min( + self.Y_n_chunks if self.execute_in_parallel_on_Y else self.X_n_chunks, + self.effective_n_threads, + ) + + @final + cdef void _parallel_on_X(self) nogil: + """Compute the pairwise distances of each row vector of X on Y + by parallelizing computation on the outer loop on chunks of X + and reduce them. + + This strategy dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. 
+ """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Allocating thread datastructures + self._parallel_on_X_parallel_init(thread_num) + + for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1: + X_end = X_start + self.X_n_samples_last_chunk + else: + X_end = X_start + self.X_n_samples_chunk + + # Reinitializing thread datastructures for the new X chunk + self._parallel_on_X_init_chunk(thread_num, X_start) + + for Y_chunk_idx in range(self.Y_n_chunks): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1: + Y_end = Y_start + self.Y_n_samples_last_chunk + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + # Adjusting thread datastructures on the full pass on Y + self._parallel_on_X_prange_iter_finalize(thread_num, X_start, X_end) + + # end: for X_chunk_idx + + # Deallocating thread datastructures + self._parallel_on_X_parallel_finalize(thread_num) + + # end: with nogil, parallel + return + + @final + cdef void _parallel_on_Y(self) nogil: + """Compute the pairwise distances of each row vector of X on Y + by parallelizing computation on the inner loop on chunks of Y + and reduce them. + + This strategy dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. + """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + # Allocating datastructures shared by all threads + self._parallel_on_Y_init() + + for X_chunk_idx in range(self.X_n_chunks): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1: + X_end = X_start + self.X_n_samples_last_chunk + else: + X_end = X_start + self.X_n_samples_chunk + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Initializing datastructures used in this thread + self._parallel_on_Y_parallel_init(thread_num) + + for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1: + Y_end = Y_start + self.Y_n_samples_last_chunk + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + # end: prange + + # Note: we don't need a _parallel_on_Y_finalize similarly. + # This can be introduced if needed. 
+ + # end: with nogil, parallel + + # Synchronizing the thread datastructures with the main ones + self._parallel_on_Y_synchronize(X_start, X_end) + + # end: for X_chunk_idx + # Deallocating temporary datastructures and adjusting main datastructures + self._parallel_on_Y_finalize() + return + + # Placeholder methods which have to be implemented + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Compute the pairwise distances on two chunks of X and Y and reduce them. + + This is THE core computational method of PairwiseDistanceReductions. + This must be implemented in subclasses. + """ + return + + def _finalize_results(self, bint return_distance): + """Callback adapting datastructures before returning results. + + This must be implemented in subclasses. + """ + return None + + # Placeholder methods which can be implemented + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to exact distances or recompute them.""" + return + + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + """Allocate datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Interact with datastructures after a reduction on chunks.""" + return + + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + """Interact with datastructures after executing all the reductions.""" + return + + cdef void _parallel_on_Y_init( + self, + ) nogil: + """Allocate datastructures used in all threads.""" + return + + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Update thread datastructures before leaving a parallel region.""" + return + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + """Update datastructures after executing all the reductions.""" + return + +cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of row vectors of X on the ones of Y. + + For each row vector of X, computes the indices of k first the rows + vectors of Y with the smallest distances. + + PairwiseDistancesArgKmin is typically used to perform + bruteforce k-nearest neighbors queries. + + Parameters + ---------- + datasets_pair: DatasetsPair + The dataset pairs (X, Y) for the reduction. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + n_threads: int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on + :meth:`~PairwiseDistancesArgKmin.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + k: int, default=1 + The k for the argkmin reduction. 
+ """ + + cdef: + ITYPE_t k + + ITYPE_t[:, ::1] argkmin_indices + DTYPE_t[:, ::1] argkmin_distances + + # Used as array of pointers to private datastructures used in threads. + DTYPE_t ** heaps_r_distances_chunks + ITYPE_t ** heaps_indices_chunks + + @classmethod + def compute( + cls, + X, + Y, + ITYPE_t k, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + n_threads=None, + str strategy=None, + bint return_distance=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + n_threads : int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on + :meth:`~PairwiseDistancesArgKmin.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + If return_distance=False: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + + If return_distance=True: + - argkmin_indices : ndarray of shape (n_samples_X, k) + Indices of the argkmin for each vector in X. + - argkmin_distances : ndarray of shape (n_samples_X, k) + Distances to the argkmin for each vector in X. 
+ + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private :meth:`PairwiseDistancesArgKmin._compute` + instance method of the most appropriate :class:`PairwiseDistancesArgKmin` + concrete implementation. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. + """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various back-end and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. + + pda = PairwiseDistancesArgKmin( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results(return_distance) + + def __cinit__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + n_threads=None, + strategy=None, + *args, + **kwargs, + ): + # Allocating pointers to datastructures but not the datastructures themselves. + # There are as many pointers as effective threads. + # + # For the sake of explicitness: + # - when parallelizing on X, the pointers of those heaps are referencing + # (with proper offsets) addresses of the two main heaps (see bellow) + # - when parallelizing on Y, the pointers of those heaps are referencing + # small heaps which are thread-wise-allocated and whose content will be + # merged with the main heaps'. + self.heaps_r_distances_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + self.heaps_indices_chunks = malloc( + sizeof(ITYPE_t *) * self.chunks_n_threads + ) + + def __init__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + n_threads=None, + strategy=None, + ITYPE_t k=1, + ): + self.k = check_scalar(k, "k", Integral, min_val=1) + + # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. + self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) + self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) + + def __dealloc__(self): + if self.heaps_indices_chunks is not NULL: + free(self.heaps_indices_chunks) + + if self.heaps_r_distances_chunks is not NULL: + free(self.heaps_r_distances_chunks) + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + ITYPE_t n_samples_X = X_end - X_start + ITYPE_t n_samples_Y = Y_end - Y_start + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Pushing the distances and their associated indices on a heap + # which by construction will keep track of the argkmin. 
+ for i in range(n_samples_X): + for j in range(n_samples_Y): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), + Y_start + j, + ) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ) nogil: + # As this strategy is embarrassingly parallel, we can set each + # thread's heaps pointer to the proper position on the main heaps. + self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] + self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx + + # Sorting the main heaps portion associated to `X[X_start:X_end]` + # in ascending order w.r.t the distances. + for idx in range(X_end - X_start): + simultaneous_sort( + self.heaps_r_distances_chunks[thread_num] + idx * self.k, + self.heaps_indices_chunks[thread_num] + idx * self.k, + self.k + ) + + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef: + # Maximum number of scalar elements (the last chunks can be smaller) + ITYPE_t heaps_size = self.X_n_samples_chunk * self.k + ITYPE_t thread_num + + # The allocation is done in parallel for data locality purposes: this way + # the heaps used in each threads are allocated in pages which are closer + # to the CPU core used by the thread. + # See comments about First Touch Placement Policy: + # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa + for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True, + num_threads=self.chunks_n_threads): + # As chunks of X are shared across threads, so must their + # heaps. To solve this, each thread has its own heaps + # which are then synchronised back in the main ones. + self.heaps_r_distances_chunks[thread_num] = malloc( + heaps_size * sizeof(DTYPE_t) + ) + self.heaps_indices_chunks[thread_num] = malloc( + heaps_size * sizeof(ITYPE_t) + ) + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + # Initialising heaps (memset can't be used here) + for idx in range(self.X_n_samples_chunk * self.k): + self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX + self.heaps_indices_chunks[thread_num][idx] = -1 + + @final + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx, thread_num + with nogil, parallel(num_threads=self.effective_n_threads): + # Synchronising the thread heaps with the main heaps. + # This is done in parallel sample-wise (no need for locks). + # + # This might break each thread's data locality as each heap which + # was allocated in a thread is being now being used in several threads. + # + # Still, this parallel pattern has shown to be efficient in practice. 
+ for idx in prange(X_end - X_start, schedule="static"): + for thread_num in range(self.chunks_n_threads): + for jdx in range(self.k): + heap_push( + &self.argkmin_distances[X_start + idx, 0], + &self.argkmin_indices[X_start + idx, 0], + self.k, + self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], + self.heaps_indices_chunks[thread_num][idx * self.k + jdx], + ) + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef: + ITYPE_t idx, thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + # Deallocating temporary datastructures + for thread_num in prange(self.chunks_n_threads, schedule='static'): + free(self.heaps_r_distances_chunks[thread_num]) + free(self.heaps_indices_chunks[thread_num]) + + # Sorting the main in ascending order w.r.t the distances. + # This is done in parallel sample-wise (no need for locks). + for idx in prange(self.n_samples_X, schedule='static'): + simultaneous_sort( + &self.argkmin_distances[idx, 0], + &self.argkmin_indices[idx, 0], + self.k, + ) + return + + cdef void compute_exact_distances(self) nogil: + cdef: + ITYPE_t i, j + ITYPE_t[:, ::1] Y_indices = self.argkmin_indices + DTYPE_t[:, ::1] distances = self.argkmin_distances + for i in prange(self.n_samples_X, schedule='static', nogil=True, + num_threads=self.effective_n_threads): + for j in range(self.k): + distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(distances[i, j], 0.) + ) + + def _finalize_results(self, bint return_distance=False): + if return_distance: + # We need to recompute distances because we relied on + # surrogate distances for the reduction. + self.compute_exact_distances() + return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances) + + return np.asarray(self.argkmin_indices) diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index 69925a3590be6..1c26d9969397c 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -26,6 +26,13 @@ def configuration(parent_package="", top_path=None): libraries=libraries, ) + config.add_extension( + "_pairwise_distances_reduction", + sources=["_pairwise_distances_reduction.pyx"], + include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")], + libraries=libraries, + ) + config.add_subpackage("tests") return config diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py new file mode 100644 index 0000000000000..a4d51e4662740 --- /dev/null +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -0,0 +1,357 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal, assert_allclose +from scipy.sparse import csr_matrix +from scipy.spatial.distance import cdist + +from sklearn.metrics._pairwise_distances_reduction import ( + PairwiseDistancesReduction, + PairwiseDistancesArgKmin, +) + +from sklearn.utils.fixes import sp_version, parse_version + +# Common supported metric between scipy.spatial.distance.cdist +# and PairwiseDistancesReduction. +# This allows constructing tests to check consistency of results +# of concrete PairwiseDistancesReduction on some metrics using APIs +# from scipy and numpy. 
+CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS = [ + "braycurtis", + "canberra", + "chebyshev", + "cityblock", + "euclidean", + "minkowski", + "seuclidean", +] + + +def _get_dummy_metric_params_list(metric: str, n_features: int): + """Return list of dummy DistanceMetric kwargs for tests.""" + + # Distinguishing on cases not to compute unneeded datastructures. + rng = np.random.RandomState(1) + + if metric == "minkowski": + minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)] + if sp_version >= parse_version("1.8.0.dev0"): + # TODO: remove the test once we no longer support scipy < 1.8.0. + # Recent scipy versions accept weights in the Minkowski metric directly: + # type: ignore + minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features))) + + return minkowski_kwargs + + # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0. + if metric == "wminkowski": + weights = rng.random_sample(n_features) + weights /= weights.sum() + wminkowski_kwargs = [dict(p=1.5, w=weights)] + if sp_version < parse_version("1.8.0.dev0"): + # wminkowski was removed in scipy 1.8.0 but should work for previous + # versions. + wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features))) + return wminkowski_kwargs + + if metric == "seuclidean": + return [dict(V=rng.rand(n_features))] + + # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other metric. + # In those cases, no kwargs is needed. + return [{}] + + +def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices): + assert_array_equal( + ref_indices, + indices, + err_msg="Query vectors have different neighbors' indices", + ) + assert_allclose( + ref_dist, + dist, + err_msg="Query vectors have different neighbors' distances", + rtol=1e-7, + ) + + +ASSERT_RESULT = { + PairwiseDistancesArgKmin: assert_argkmin_results_equality, +} + + +def test_pairwise_distances_reduction_is_usable_for(): + rng = np.random.RandomState(0) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + assert PairwiseDistancesReduction.is_usable_for(X, Y, metric) + assert not PairwiseDistancesReduction.is_usable_for( + X.astype(np.int64), Y.astype(np.int64), metric + ) + + assert not PairwiseDistancesReduction.is_usable_for(X, Y, metric="pyfunc") + # TODO: remove once 32 bits datasets are supported + assert not PairwiseDistancesReduction.is_usable_for(X.astype(np.float32), Y, metric) + assert not PairwiseDistancesReduction.is_usable_for(X, Y.astype(np.int32), metric) + + # TODO: remove once sparse matrices are supported + assert not PairwiseDistancesReduction.is_usable_for(csr_matrix(X), Y, metric) + assert not PairwiseDistancesReduction.is_usable_for(X, csr_matrix(Y), metric) + + +def test_argkmin_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + k = 5 + metric = "euclidean" + + msg = ( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float32 and Y.dtype=float64" + ) + with pytest.raises(ValueError, match=msg): + PairwiseDistancesArgKmin.compute( + X=X.astype(np.float32), Y=Y, k=k, metric=metric + ) + + msg = ( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float64 and Y.dtype=int32" + ) + with pytest.raises(ValueError, match=msg): + PairwiseDistancesArgKmin.compute(X=X, Y=Y.astype(np.int32), k=k, metric=metric) + + with pytest.raises(ValueError, match="k == -1, must be >= 1."): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=-1, metric=metric) + + with pytest.raises(ValueError, 
match="k == 0, must be >= 1."): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=0, metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=k, metric="wrong metric") + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistancesArgKmin.compute( + X=np.array([1.0, 2.0]), Y=Y, k=k, metric=metric + ) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistancesArgKmin.compute( + X=np.asfortranarray(X), Y=Y, k=k, metric=metric + ) + + +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("chunk_size", [50, 512, 1024]) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_chunk_size_agnosticism( + PairwiseDistancesReduction, + seed, + n_samples, + chunk_size, + n_features=100, + dtype=np.float64, +): + # Results should not depend on the chunk size + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the numbers of dimensions + else 10 ** np.log(n_features) + ) + + ref_indices, ref_dist = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + return_distance=True, + ) + + indices, dist = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + chunk_size=chunk_size, + return_distance=True, + ) + + ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + + +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("chunk_size", [50, 512, 1024]) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_n_threads_agnosticism( + PairwiseDistancesReduction, + seed, + n_samples, + chunk_size, + n_features=100, + dtype=np.float64, +): + # Results should not depend on the number of threads + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the numbers of dimensions + else 10 ** np.log(n_features) + ) + + ref_indices, ref_dist = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + return_distance=True, + ) + + indices, dist = PairwiseDistancesReduction.compute( + X, Y, parameter, n_threads=1, return_distance=True + ) + + ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + + +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics()) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_strategies_consistency( + PairwiseDistancesReduction, + metric, + n_samples, + seed, + n_features=10, + dtype=np.float64, +): + + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = 
np.ascontiguousarray(Y[:, :2]) + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the numbers of dimensions + else 10 ** np.log(n_features) + ) + + indices_par_X, dist_par_X = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + # Taking the first + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + # To be sure to use parallelization + chunk_size=n_samples // 4, + strategy="parallel_on_X", + return_distance=True, + ) + + indices_par_Y, dist_par_Y = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + # Taking the first + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + # To be sure to use parallelization + chunk_size=n_samples // 4, + strategy="parallel_on_Y", + return_distance=True, + ) + + ASSERT_RESULT[PairwiseDistancesReduction]( + dist_par_X, + dist_par_Y, + indices_par_X, + indices_par_Y, + ) + + +# Concrete PairwiseDistancesReductions tests + + +@pytest.mark.parametrize("n_features", [50, 500]) +@pytest.mark.parametrize("translation", [0, 1e8]) +@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) +@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) +def test_pairwise_distances_argkmin( + n_features, + translation, + metric, + strategy, + n_samples=100, + k=10, + dtype=np.float64, +): + rng = np.random.RandomState(0) + spread = 1000 + X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = np.ascontiguousarray(Y[:, :2]) + + metric_kwargs = _get_dummy_metric_params_list(metric, n_features)[0] + + # Reference for argkmin results + dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) + # Taking argkmin (indices of the k smallest values) + argkmin_indices_ref = np.argsort(dist_matrix, axis=1)[:, :k] + # Getting the associated distances + argkmin_distances_ref = np.zeros(argkmin_indices_ref.shape, dtype=np.float64) + for row_idx in range(argkmin_indices_ref.shape[0]): + argkmin_distances_ref[row_idx] = dist_matrix[ + row_idx, argkmin_indices_ref[row_idx] + ] + + argkmin_indices, argkmin_distances = PairwiseDistancesArgKmin.compute( + X, + Y, + k, + metric=metric, + metric_kwargs=metric_kwargs, + return_distance=True, + # So as to have more than a chunk, forcing parallelism. 
+ chunk_size=n_samples // 4, + strategy=strategy, + ) + + ASSERT_RESULT[PairwiseDistancesArgKmin]( + argkmin_distances, argkmin_distances_ref, argkmin_indices, argkmin_indices_ref + ) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index f78a9ff30b10a..e99eb5fc9db82 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -16,6 +16,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } # Not using as a context manager affects nothing @@ -28,6 +29,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } assert get_config()["assume_finite"] is False @@ -57,6 +59,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } # No positional arguments diff --git a/sklearn/utils/_openmp_helpers.pxd b/sklearn/utils/_openmp_helpers.pxd new file mode 100644 index 0000000000000..e57fc9bfa6bf5 --- /dev/null +++ b/sklearn/utils/_openmp_helpers.pxd @@ -0,0 +1,6 @@ +# Helpers to access OpenMP threads information +# +# Those interfaces act as indirections which allows the non-support of OpenMP +# for implementations which have been written for it. + +cdef int _openmp_thread_num() nogil diff --git a/sklearn/utils/_openmp_helpers.pyx b/sklearn/utils/_openmp_helpers.pyx index fb8920074a84e..cddd77ac42746 100644 --- a/sklearn/utils/_openmp_helpers.pyx +++ b/sklearn/utils/_openmp_helpers.pyx @@ -6,7 +6,7 @@ IF SKLEARN_OPENMP_PARALLELISM_ENABLED: def _openmp_parallelism_enabled(): """Determines whether scikit-learn has been built with OpenMP - + It allows to retrieve at runtime the information gathered at compile time. """ # SKLEARN_OPENMP_PARALLELISM_ENABLED is resolved at compile time during @@ -22,7 +22,7 @@ cpdef _openmp_effective_n_threads(n_threads=None): - if the ``OMP_NUM_THREADS`` environment variable is set, return ``openmp.omp_get_max_threads()`` - otherwise, return the minimum between ``openmp.omp_get_max_threads()`` - and the number of cpus, taking cgroups quotas into account. Cgroups + and the number of cpus, taking cgroups quotas into account. Cgroups quotas can typically be set by tools such as Docker. The result of ``omp_get_max_threads`` can be influenced by environment variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. @@ -59,4 +59,13 @@ cpdef _openmp_effective_n_threads(n_threads=None): # OpenMP disabled at build-time => sequential mode return 1 - + +cdef inline int _openmp_thread_num() nogil: + """Return the number of the thread calling this function. + + If scikit-learn is built without OpenMP support, always return 0. + """ + IF SKLEARN_OPENMP_PARALLELISM_ENABLED: + return openmp.omp_get_thread_num() + ELSE: + return 0 From 60dcd130c99aa60dab2daf66ae31c3de7b20626f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 13 Jan 2022 16:41:47 +0100 Subject: [PATCH 02/30] MAINT Introduce FastEuclideanPairwiseArgKmin (#22065) Co-authored-by: Olivier Grisel Co-authored-by: Thomas J. 
Fan --- .../metrics/_pairwise_distances_reduction.pyx | 302 ++++++++++++++++-- .../test_pairwise_distances_reduction.py | 31 +- sklearn/utils/__init__.py | 35 +- sklearn/utils/_testing.py | 11 +- 4 files changed, 347 insertions(+), 32 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 830df08e1a952..76420b50a1b5e 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -16,7 +16,6 @@ cimport numpy as np import numpy as np import warnings -import scipy.sparse from .. import get_config from libc.stdlib cimport free, malloc @@ -24,7 +23,17 @@ from libc.float cimport DBL_MAX from cython cimport final from cython.parallel cimport parallel, prange -from ._dist_metrics cimport DatasetsPair +from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair +from ..utils._cython_blas cimport ( + BLAS_Order, + BLAS_Trans, + ColMajor, + NoTrans, + RowMajor, + Trans, + _dot, + _gemm, +) from ..utils._heap cimport simultaneous_sort, heap_push from ..utils._openmp_helpers cimport _openmp_thread_num from ..utils._typedefs cimport ITYPE_t, DTYPE_t @@ -33,7 +42,7 @@ from numbers import Integral from typing import List from scipy.sparse import issparse from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING -from ..utils import check_scalar +from ..utils import check_scalar, _in_unstable_openblas_configuration from ..utils.fixes import threadpool_limits from ..utils._openmp_helpers import _openmp_effective_n_threads from ..utils._typedefs import ITYPE, DTYPE @@ -41,11 +50,37 @@ from ..utils._typedefs import ITYPE, DTYPE np.import_array() +cpdef DTYPE_t[::1] _sqeuclidean_row_norms( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t idx = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) + + for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): + squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return squared_row_norms + + + cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. Subclasses of this class compute pairwise distances between a set of - row vectors of X and another set of row vectors pf Y and apply a reduction on top. + row vectors of X and another set of row vectors of Y and apply a reduction on top. The reduction takes a matrix of pairwise distances between rows of X and Y as input and outputs an aggregate data-structure for each row of X. 
The aggregate values are typically smaller than the number of rows in Y, @@ -180,14 +215,12 @@ cdef class PairwiseDistancesReduction: not issparse(Y) and Y.dtype == np.float64 and metric in cls.valid_metrics()) - def __cinit__( + def __init__( self, DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, - *args, - **kwargs, ): cdef: ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks @@ -611,13 +644,32 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): # For future work, this might can be an entrypoint to specialise operations # for various back-end and/or hardware and/or datatypes, and/or fused # {sparse, dense}-datasetspair etc. - - pda = PairwiseDistancesArgKmin( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - ) + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesArgKmin( + X=X, Y=Y, k=k, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = PairwiseDistancesArgKmin( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). @@ -629,15 +681,22 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return pda._finalize_results(return_distance) - def __cinit__( + def __init__( self, DatasetsPair datasets_pair, chunk_size=None, n_threads=None, strategy=None, - *args, - **kwargs, - ): + ITYPE_t k=1, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + ) + self.k = check_scalar(k, "k", Integral, min_val=1) + # Allocating pointers to datastructures but not the datastructures themselves. # There are as many pointers as effective threads. # @@ -654,16 +713,6 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): sizeof(ITYPE_t *) * self.chunks_n_threads ) - def __init__( - self, - DatasetsPair datasets_pair, - chunk_size=None, - n_threads=None, - strategy=None, - ITYPE_t k=1, - ): - self.k = check_scalar(k, "k", Integral, min_val=1) - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin.compute`. self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -837,3 +886,200 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): return np.asarray(self.argkmin_indices), np.asarray(self.argkmin_distances) return np.asarray(self.argkmin_indices) + + +cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): + """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance. 
+ + The full pairwise squared distances matrix is computed as follows: + + ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² + + The middle term gets computed efficiently bellow using BLAS Level 3 GEMM. + + Notes + ----- + This implementation has a superior arithmetic intensity and hence + better running time when the alternative is IO bound, but it can suffer + from numerical instability caused by catastrophic cancellation potentially + introduced by the subtraction in the arithmetic expression above. + """ + + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + # Buffers for GEMM + DTYPE_t ** dist_middle_terms_chunks + bint use_squared_distances + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and + not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + ITYPE_t k, + bint use_squared_distances=False, + chunk_size=None, + n_threads=None, + strategy=None, + metric_kwargs=None, + ): + if metric_kwargs is not None and len(metric_kwargs) > 0: + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't" + f"usable for this case ({self.__class__.__name__}) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + k=k, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + self.X, self.Y = datasets_pair.X, datasets_pair.Y + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + else: + self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. 
+ self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms(self.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + + def __dealloc__(self): + if self.dist_middle_terms_chunks is not NULL: + free(self.dist_middle_terms_chunks) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesArgKmin.compute_exact_distances(self) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num) + + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num) + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_init(self) + + for thread_num in range(self.chunks_n_threads): + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_finalize(self) + + for thread_num in range(self.chunks_n_threads): + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * A = & X_c[0, 0] + ITYPE_t lda = X_c.shape[1] + DTYPE_t * B = & Y_c[0, 0] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + + # Pushing the distance and their associated indices on heaps + # which keep tracks of the argkmin. 
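As a plain-NumPy illustration of the expansion used by this kernel, the GEMM term computed just above can be combined with the precomputed squared row norms to recover the squared Euclidean distances. In this sketch, einsum computes the row norms (the quantity `_sqeuclidean_row_norms` returns row-wise) and scipy's cdist only serves as a reference check:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.RandomState(0)
X, Y = rng.rand(6, 4), rng.rand(8, 4)

# Squared row norms, i.e. what _sqeuclidean_row_norms computes per row.
X_norm_squared = np.einsum("ij,ij->i", X, X)
Y_norm_squared = np.einsum("ij,ij->i", Y, Y)

# The GEMM term: -2 X . Y^T
dist_middle_terms = -2.0 * X @ Y.T

# ||X_i - Y_j||² = ||X_i||² - 2 X_i . Y_j^T + ||Y_j||²
sq_dists = X_norm_squared[:, None] + dist_middle_terms + Y_norm_squared[None, :]

np.testing.assert_allclose(sq_dists, cdist(X, Y, metric="sqeuclidean"), atol=1e-12)
# The max(., 0) guards against tiny negative values caused by cancellation.
np.testing.assert_allclose(np.sqrt(np.maximum(sq_dists, 0.0)), cdist(X, Y), atol=1e-12)

In the loop below, each (i, j) entry of such a block is pushed onto the per-row heap instead of being stored, so only the k best candidates per row of X_c survive.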
+ for i in range(X_c.shape[0]): + for j in range(Y_c.shape[0]): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * Y_c.shape[0] + j] + + self.Y_norm_squared[j + Y_start] + ), + j + Y_start, + ) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index a4d51e4662740..e975aad55bf9c 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -7,8 +7,10 @@ from sklearn.metrics._pairwise_distances_reduction import ( PairwiseDistancesReduction, PairwiseDistancesArgKmin, + _sqeuclidean_row_norms, ) +from sklearn.metrics import euclidean_distances from sklearn.utils.fixes import sp_version, parse_version # Common supported metric between scipy.spatial.distance.cdist @@ -305,7 +307,7 @@ def test_strategies_consistency( @pytest.mark.parametrize("n_features", [50, 500]) -@pytest.mark.parametrize("translation", [0, 1e8]) +@pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) def test_pairwise_distances_argkmin( @@ -330,7 +332,11 @@ def test_pairwise_distances_argkmin( metric_kwargs = _get_dummy_metric_params_list(metric, n_features)[0] # Reference for argkmin results - dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) + if metric == "euclidean": + # Compare to scikit-learn GEMM optimized implementation + dist_matrix = euclidean_distances(X, Y) + else: + dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) # Taking argkmin (indices of the k smallest values) argkmin_indices_ref = np.argsort(dist_matrix, axis=1)[:, :k] # Getting the associated distances @@ -355,3 +361,24 @@ def test_pairwise_distances_argkmin( ASSERT_RESULT[PairwiseDistancesArgKmin]( argkmin_distances, argkmin_distances_ref, argkmin_indices, argkmin_indices_ref ) + + +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("n_features", [5, 10, 100]) +@pytest.mark.parametrize("num_threads", [1, 2, 8]) +def test_sqeuclidean_row_norms( + seed, + n_samples, + n_features, + num_threads, + dtype=np.float64, +): + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + + sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 + sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) + + assert_allclose(sq_row_norm_reference, sq_row_norm) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 3d8a1ca87d210..4b2261ad7c2f4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -26,7 +26,7 @@ from . import _joblib from ..exceptions import DataConversionWarning from .deprecation import deprecated -from .fixes import np_version, parse_version +from .fixes import np_version, parse_version, threadpool_info from ._estimator_html_repr import estimator_html_repr from .validation import ( as_float_array, @@ -81,6 +81,39 @@ _IS_32BIT = 8 * struct.calcsize("P") == 32 +def _in_unstable_openblas_configuration(): + """Return True if in an unstable configuration for OpenBLAS""" + + # Import libraries which might load OpenBLAS. 
+ import numpy # noqa + import scipy # noqa + + modules_info = threadpool_info() + + open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info) + if not open_blas_used: + return False + + # OpenBLAS 0.3.16 fixed unstability for arm64, see: + # https://github.com/xianyi/OpenBLAS/blob/1b6db3dbba672b4f8af935bd43a1ff6cff4d20b7/Changelog.txt#L56-L58 # noqa + openblas_arm64_stable_version = parse_version("0.3.16") + for info in modules_info: + if info["internal_api"] != "openblas": + continue + openblas_version = info.get("version") + openblas_architecture = info.get("architecture") + if openblas_version is None or openblas_architecture is None: + # Cannot be sure that OpenBLAS is good enough. Assume unstable: + return True + if ( + openblas_architecture == "neoversen1" + and parse_version(openblas_version) < openblas_arm64_stable_version + ): + # See discussions in https://github.com/numpy/numpy/issues/19411 + return True + return False + + class Bunch(dict): """Container object exposing keys as attributes. diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 1724063be2f43..6f58ce3f3b7b4 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -48,7 +48,12 @@ import joblib import sklearn -from sklearn.utils import IS_PYPY, _IS_32BIT, deprecated +from sklearn.utils import ( + IS_PYPY, + _IS_32BIT, + deprecated, + _in_unstable_openblas_configuration, +) from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import ( check_array, @@ -448,6 +453,10 @@ def set_random_state(estimator, random_state=0): os.environ.get("TRAVIS") == "true", reason="skip on travis" ) fails_if_pypy = pytest.mark.xfail(IS_PYPY, reason="not compatible with PyPy") + fails_if_unstable_openblas = pytest.mark.xfail( + _in_unstable_openblas_configuration(), + reason="OpenBLAS is unstable for this configuration", + ) skip_if_no_parallel = pytest.mark.skipif( not joblib.parallel.mp, reason="joblib is in serial mode" ) From 1abb44186fc534d4d2ed1c7f96103f9621dff261 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 13 Jan 2022 18:35:53 +0100 Subject: [PATCH 03/30] fixup! Merge branch 'main' into pairwise-distances-argkmin Remove duplicated Bunch --- sklearn/utils/__init__.py | 50 --------------------------------------- 1 file changed, 50 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 279e6cca8d335..1fc622f6a4538 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -116,56 +116,6 @@ def _in_unstable_openblas_configuration(): return False -class Bunch(dict): - """Container object exposing keys as attributes. - - Bunch objects are sometimes used as an output for functions and methods. - They extend dictionaries by enabling values to be accessed by key, - `bunch["value_key"]`, or by an attribute, `bunch.value_key`. - - Examples - -------- - >>> from sklearn.utils import Bunch - >>> b = Bunch(a=1, b=2) - >>> b['b'] - 2 - >>> b.b - 2 - >>> b.a = 3 - >>> b['a'] - 3 - >>> b.c = 6 - >>> b['c'] - 6 - """ - - def __init__(self, **kwargs): - super().__init__(kwargs) - - def __setattr__(self, key, value): - self[key] = value - - def __dir__(self): - return self.keys() - - def __getattr__(self, key): - try: - return self[key] - except KeyError: - raise AttributeError(key) - - def __setstate__(self, state): - # Bunch pickles generated with scikit-learn 0.16.* have an non - # empty __dict__. 
This causes a surprising behaviour when - # loading these pickles scikit-learn 0.17: reading bunch.key - # uses __dict__ but assigning to bunch.key use __setattr__ and - # only changes bunch['key']. More details can be found at: - # https://github.com/scikit-learn/scikit-learn/issues/6196. - # Overriding __setstate__ to be a noop has the effect of - # ignoring the pickled __dict__ - pass - - def safe_mask(X, mask): """Return a mask which is safe to use on X. From ac6f623d1639c52673d937b7a175312e37652e01 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Jan 2022 15:10:12 +0100 Subject: [PATCH 04/30] ENH Introduce PairwiseDistancesRadiusNeighborhood --- .../metrics/_pairwise_distances_reduction.pyx | 732 +++++++++++++++++- sklearn/metrics/setup.py | 1 + .../test_pairwise_distances_reduction.py | 74 +- 3 files changed, 802 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 76420b50a1b5e..9ae5437cf631a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -20,8 +20,11 @@ import warnings from .. import get_config from libc.stdlib cimport free, malloc from libc.float cimport DBL_MAX +from libcpp.vector cimport vector from cython cimport final +from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange +from cpython.ref cimport Py_INCREF from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair from ..utils._cython_blas cimport ( @@ -36,9 +39,10 @@ from ..utils._cython_blas cimport ( ) from ..utils._heap cimport simultaneous_sort, heap_push from ..utils._openmp_helpers cimport _openmp_thread_num -from ..utils._typedefs cimport ITYPE_t, DTYPE_t +from ..utils._typedefs cimport ITYPE_t, DTYPE_t, DITYPE_t +from ..utils._typedefs cimport ITYPECODE, DTYPECODE -from numbers import Integral +from numbers import Integral, Real from typing import List from scipy.sparse import issparse from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING @@ -50,6 +54,60 @@ from ..utils._typedefs import ITYPE, DTYPE np.import_array() +# TODO: change for `libcpp.algorithm.move` once Cython 3 is used +# Introduction in Cython: +# https://github.com/cython/cython/blob/05059e2a9b89bf6738a7750b905057e5b1e3fe2e/Cython/Includes/libcpp/algorithm.pxd#L47 #noqa +cdef extern from "" namespace "std" nogil: + OutputIt move[InputIt, OutputIt](InputIt first, InputIt last, OutputIt d_first) except + #noqa + +###################### +## std::vector to np.ndarray coercion +# As type covariance is not supported for C++ containers via Cython, +# we need to redefine fused types. +ctypedef fused vector_DITYPE_t: + vector[ITYPE_t] + vector[DTYPE_t] + + +ctypedef fused vector_vector_DITYPE_t: + vector[vector[ITYPE_t]] + vector[vector[DTYPE_t]] + + +cdef class StdVectorSentinel: + """Wraps a reference to a vector which will be deallocated with this object. + + When created, the StdVectorSentinel swaps the reference of its internal + vectors with the provided one (vec_ptr), thus making the StdVectorSentinel + manage the provided one's lifetime. + """ + pass + + +# We necessarily need to define two extension types extending StdVectorSentinel +# because we need to provide the dtype of the vector but can't use numeric fused types. 
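The lifetime trick implemented by the two concrete sentinels below can be mimicked at the Python level: an ndarray built over externally owned memory does not own its buffer and instead keeps the owner alive through its `base` attribute, which is what `np.PyArray_SetBaseObject` does for the sentinel in `vector_to_nd_array` further down. A minimal analogue, with a bytearray standing in for the C++ vector's heap buffer:

import numpy as np

buf = bytearray(8 * 5)          # stand-in for the std::vector's heap buffer
arr = np.frombuffer(buf, dtype=np.float64)

assert not arr.flags.owndata    # the ndarray did not allocate this memory
assert arr.base is not None     # it holds a reference that keeps the buffer alive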
+cdef class StdVectorSentinelDTYPE(StdVectorSentinel): + cdef vector[DTYPE_t] vec + + @staticmethod + cdef StdVectorSentinel create_for(vector[DTYPE_t] * vec_ptr): + # This initializes the object directly without calling __init__ + cdef StdVectorSentinelDTYPE sentinel = StdVectorSentinelDTYPE.__new__(StdVectorSentinelDTYPE) + sentinel.vec.swap(deref(vec_ptr)) + return sentinel + + +cdef class StdVectorSentinelITYPE(StdVectorSentinel): + cdef vector[ITYPE_t] vec + + @staticmethod + cdef StdVectorSentinel create_for(vector[ITYPE_t] * vec_ptr): + # This initializes the object directly without calling __init__ + cdef StdVectorSentinelITYPE sentinel = StdVectorSentinelITYPE.__new__(StdVectorSentinelITYPE) + sentinel.vec.swap(deref(vec_ptr)) + return sentinel + + cpdef DTYPE_t[::1] _sqeuclidean_row_norms( const DTYPE_t[:, ::1] X, ITYPE_t num_threads, @@ -74,6 +132,49 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms( return squared_row_norms +cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): + """Create a numpy ndarray given a C++ vector. + + The numpy array buffer is the one of the C++ vector. + A StdVectorSentinel is registered as the base object for the numpy array, + freeing the C++ vector it encapsulates when the numpy array is freed. + """ + typenum = DTYPECODE if vector_DITYPE_t is vector[DTYPE_t] else ITYPECODE + cdef: + np.npy_intp size = deref(vect_ptr).size() + np.ndarray arr = np.PyArray_SimpleNewFromData(1, &size, typenum, + deref(vect_ptr).data()) + StdVectorSentinel sentinel + + if vector_DITYPE_t is vector[DTYPE_t]: + sentinel = StdVectorSentinelDTYPE.create_for(vect_ptr) + else: + sentinel = StdVectorSentinelITYPE.create_for(vect_ptr) + + # Makes the numpy array responsible of the life-cycle of its buffer. + # A reference to the StdVectorSentinel will be stolen by the call bellow, + # so we increase its reference counter. + # See: https://docs.python.org/3/c-api/intro.html#reference-count-details + Py_INCREF(sentinel) + np.PyArray_SetBaseObject(arr, sentinel) + return arr + + +cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( + vector_vector_DITYPE_t* vecs +): + """Coerce a std::vector of std::vector to a ndarray of ndarray.""" + cdef: + ITYPE_t n = deref(vecs).size() + np.ndarray[object, ndim=1] nd_arrays_of_nd_arrays = np.empty(n, + dtype=np.ndarray) + + for i in range(n): + nd_arrays_of_nd_arrays[i] = vector_to_nd_array(&(deref(vecs)[i])) + + return nd_arrays_of_nd_arrays + +##################### cdef class PairwiseDistancesReduction: @@ -1083,3 +1184,630 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): ), j + Y_start, ) + +cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): + """Compute radius-based neighbors for two sets of vectors. + + For each row-vector X[i] of the queries X, find all the indices j of + row-vectors in Y such that: + + dist(X[i], Y[j]) <= radius + + The distance function `dist` depends on the values of the `metric` + and `metric_kwargs` parameters. + + Parameters + ---------- + datasets_pair: DatasetsPair + The dataset pair (X, Y) for the reduction. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + n_threads: int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on + :meth:`~PairwiseDistancesRadiusNeighborhood.compute`. 
+ + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + radius: float + The radius defining the neighborhood. + """ + + cdef: + DTYPE_t radius + + # DistanceMetric compute rank-preserving surrogate distance via rdist + # which are proxies necessitating less computations. + # We get the equivalent for the radius to be able to compare it against + # vectors' rank-preserving surrogate distances. + DTYPE_t r_radius + + # Neighbors indices and distances are returned as np.ndarrays of np.ndarrays. + # + # For this implementation, we want resizable buffers which we will wrap + # into numpy arrays at the end. std::vector comes as a handy interface + # for interacting efficiently with resizable buffers. + # + # Though it is possible to access their buffer address with + # std::vector::data, they can't be stolen: buffers lifetime + # is tight to their std::vector and are deallocated when + # std::vectors are. + # + # To solve this, we dynamically allocate std::vectors and then + # encapsulate them in a StdVectorSentinel responsible for + # freeing them when the associated np.ndarray is freed. + vector[vector[ITYPE_t]] * neigh_indices + vector[vector[DTYPE_t]] * neigh_distances + + # Used as array of pointers to private datastructures used in threads. + vector[vector[ITYPE_t]] ** neigh_indices_chunks + vector[vector[DTYPE_t]] ** neigh_distances_chunks + + bint sort_results + + @classmethod + def compute( + cls, + X, + Y, + DTYPE_t radius, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + n_threads=None, + str strategy=None, + bint return_distance=False, + bint sort_results=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + radius : float + The radius defining the neighborhood. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + n_threads : int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on + :meth:`~PairwiseDistancesRadiusNeighborhood.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. 
This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. When `X.shape[0]` + is small but `Y.shape[0]` is large, 'parallel_on_Y' brings more opportunity + for parallelism and is therefore more efficient despite the synchronization + step at each iteration of the outer loop on chunks of `X`. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its neighbors if set to True. + + sort_results : boolean, default=False + Sort results with respect to distances between each X vector and its + neighbors if set to True. + + Returns + ------- + If return_distance=False: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + + If return_distance=True: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + - neighbors_distances : ndarray of n_samples_X ndarray + Distances to the neighbors for each vector in X. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private + :meth:`PairwiseDistancesRadiusNeighborhood._compute` instance method of + the most appropriate :class:`PairwiseDistancesRadiusNeighborhood` + concrete implementation. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. + """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various back-end and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. + if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + # Specialized implementation with improved arithmetic intensity + # and vector instructions (SIMD) by processing several vectors + # at time to leverage a call to the BLAS GEMM routine as explained + # in more details in the docstring. + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesRadiusNeighborhood( + X=X, Y=Y, radius=radius, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + n_threads=n_threads, + strategy=strategy, + sort_results=sort_results, + ) + else: + # Fall back on a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = PairwiseDistancesRadiusNeighborhood( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + radius=radius, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + n_threads=n_threads, + strategy=strategy, + sort_results=sort_results, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). 
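The `threadpool_limits(limits=1, user_api="blas")` context used right below comes from threadpoolctl (re-exported through `sklearn.utils.fixes`). A minimal sketch of the same pattern, assuming threadpoolctl is installed:

import numpy as np
from threadpoolctl import threadpool_limits, threadpool_info

A = np.random.rand(2000, 100)

# threadpool_info() lists the loaded BLAS/OpenMP runtimes; it is also what
# _in_unstable_openblas_configuration relies on to detect OpenBLAS.
print([info["internal_api"] for info in threadpool_info()])

with threadpool_limits(limits=1, user_api="blas"):
    # Within this block OpenBLAS/MKL run single-threaded, so the outer
    # OpenMP parallelism on chunks is the only level of parallelism and
    # the cores are not over-subscribed.
    G = A @ A.T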
+ with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results(return_distance) + + + def __init__( + self, + DatasetsPair datasets_pair, + DTYPE_t radius, + chunk_size=None, + n_threads=None, + strategy=None, + sort_results=False, + metric_kwargs=None, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + ) + + self.radius = check_scalar(radius, "radius", Real, min_val=0) + self.r_radius = self.datasets_pair.distance_metric._dist_to_rdist(radius) + self.sort_results = sort_results + + # Allocating pointers to datastructures but not the datastructures themselves. + # There are as many pointers as effective threads. + # + # For the sake of explicitness: + # - when parallelizing on X, the pointers of those heaps are referencing + # self.neigh_distances and self.neigh_indices + # - when parallelizing on Y, the pointers of those heaps are referencing + # std::vectors of std::vectors which are thread-wise-allocated and whose + # content will be merged into self.neigh_distances and self.neigh_indices. + self.neigh_distances_chunks = malloc( + sizeof(self.neigh_distances) * self.chunks_n_threads + ) + self.neigh_indices_chunks = malloc( + sizeof(self.neigh_indices) * self.chunks_n_threads + ) + + # Temporary datastructures which will be coerced to numpy arrays on before + # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. + self.neigh_distances = new vector[vector[DTYPE_t]](self.n_samples_X) + self.neigh_indices = new vector[vector[ITYPE_t]](self.n_samples_X) + + def __dealloc__(self): + if self.neigh_distances_chunks is not NULL: + free(self.neigh_distances_chunks) + + if self.neigh_indices_chunks is not NULL: + free(self.neigh_indices_chunks) + + if self.neigh_indices is not NULL: + del self.neigh_indices + + if self.neigh_distances is not NULL: + del self.neigh_distances + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t r_dist_i_j + + for i in range(X_start, X_end): + for j in range(Y_start, Y_end): + r_dist_i_j = self.datasets_pair.surrogate_dist(i, j) + if r_dist_i_j <= self.r_radius: + deref(self.neigh_distances_chunks[thread_num])[i].push_back(r_dist_i_j) + deref(self.neigh_indices_chunks[thread_num])[i].push_back(j) + + def _finalize_results(self, bint return_distance=False): + if return_distance: + # We need to recompute distances because we relied on + # surrogate distances for the reduction. + self.compute_exact_distances() + return ( + coerce_vectors_to_nd_arrays(self.neigh_distances), + coerce_vectors_to_nd_arrays(self.neigh_indices), + ) + + return coerce_vectors_to_nd_arrays(self.neigh_indices) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ) nogil: + + # As this strategy is embarrassingly parallel, we can set the + # thread vectors' pointers to the main vectors'. 
+ self.neigh_distances_chunks[thread_num] = self.neigh_distances + self.neigh_indices_chunks[thread_num] = self.neigh_indices + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx + + # Sorting neighbors for each query vector of X + if self.sort_results: + for idx in range(X_start, X_end): + simultaneous_sort( + deref(self.neigh_distances)[idx].data(), + deref(self.neigh_indices)[idx].data(), + deref(self.neigh_indices)[idx].size() + ) + + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef: + ITYPE_t thread_num + # As chunks of X are shared across threads, so must datastructures to avoid race + # conditions: each thread has its own vectors of n_samples_X vectors which are + # then merged back in the main n_samples_X vectors. + for thread_num in range(self.chunks_n_threads): + self.neigh_distances_chunks[thread_num] = new vector[vector[DTYPE_t]](self.n_samples_X) + self.neigh_indices_chunks[thread_num] = new vector[vector[ITYPE_t]](self.n_samples_X) + + @final + cdef void _merge_vectors( + self, + ITYPE_t idx, + ITYPE_t num_threads, + ) nogil: + cdef: + ITYPE_t thread_num + ITYPE_t idx_n_elements = 0 + ITYPE_t last_element_idx = deref(self.neigh_indices)[idx].size() + + # Resizing buffers only once for the given number of elements. + for thread_num in range(num_threads): + idx_n_elements += deref(self.neigh_distances_chunks[thread_num])[idx].size() + + deref(self.neigh_distances)[idx].resize(last_element_idx + idx_n_elements) + deref(self.neigh_indices)[idx].resize(last_element_idx + idx_n_elements) + + # Moving the elements by range using the range first element + # as the reference for the insertion. + for thread_num in range(num_threads): + move( + deref(self.neigh_distances_chunks[thread_num])[idx].begin(), + deref(self.neigh_distances_chunks[thread_num])[idx].end(), + deref(self.neigh_distances)[idx].begin() + last_element_idx + ) + move( + deref(self.neigh_indices_chunks[thread_num])[idx].begin(), + deref(self.neigh_indices_chunks[thread_num])[idx].end(), + deref(self.neigh_indices)[idx].begin() + last_element_idx + ) + last_element_idx += deref(self.neigh_distances_chunks[thread_num])[idx].size() + + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef: + ITYPE_t idx, jdx, thread_num, idx_n_element, idx_current + + with nogil, parallel(num_threads=self.effective_n_threads): + # Merge vectors used in threads into the main ones. + # This is done in parallel sample-wise (no need for locks) + # using dynamic scheduling because we might not have + # the same number of neighbors for each query vector. + # TODO: compare 'dynamic' vs 'static' vs 'guided' + for idx in prange(self.n_samples_X, schedule='dynamic'): + self._merge_vectors(idx, self.chunks_n_threads) + + # The content of the vector have been std::moved. + # Hence they can't be used anymore and can only be deleted. + for thread_num in prange(self.chunks_n_threads, schedule='static'): + del self.neigh_distances_chunks[thread_num] + del self.neigh_indices_chunks[thread_num] + + # Sort in parallel in ascending order w.r.t the distances if requested. 
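The merge performed by `_merge_vectors` above can be pictured with plain Python objects standing in for the per-thread std::vectors: each thread accumulates its own per-query buffers, which are then concatenated query-wise. The shapes and values below are purely hypothetical:

import numpy as np

n_samples_X = 3
# Hypothetical partial neighborhoods found by 2 threads, one entry per query.
thread_indices = [
    [np.array([0, 3]), np.array([1]), np.array([], dtype=np.intp)],  # thread 0
    [np.array([7]), np.array([4, 5]), np.array([2])],                # thread 1
]
merged = [
    np.concatenate([per_thread[i] for per_thread in thread_indices])
    for i in range(n_samples_X)
]
# merged[0] is array([0, 3, 7]); merged[2] is array([2]).

The Cython version avoids the copies by resizing each destination vector once and std::move-ing every per-thread range into it; the sorting step that follows then reorders each merged neighborhood in place.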
+ if self.sort_results: + for idx in prange(self.n_samples_X, schedule='static'): + simultaneous_sort( + deref(self.neigh_distances)[idx].data(), + deref(self.neigh_indices)[idx].data(), + deref(self.neigh_indices)[idx].size() + ) + + return + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to pairwise distances in parallel.""" + cdef: + ITYPE_t i, j + + for i in prange(self.n_samples_X, nogil=True, schedule='dynamic', + num_threads=self.effective_n_threads): + for j in range(deref(self.neigh_indices)[i].size()): + deref(self.neigh_distances)[i][j] = ( + self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(deref(self.neigh_distances)[i][j], 0.) + ) + ) + + +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRadiusNeighborhood): + """Fast specialized alternative for PairwiseDistancesRadiusNeighborhood on EuclideanDistance. + + The full pairwise squared distances matrix is computed as follows: + + ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² + + The middle term gets computed efficiently bellow using BLAS Level 3 GEMM. + + Notes + ----- + This implementation has a superior arithmetic intensity and hence + better running time when the alternative is IO bound, but it can suffer + from numerical instability caused by catastrophic cancellation potentially + introduced by the subtraction in the arithmetic expression above. + numerical precision is needed. + """ + + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + # Buffers for GEMM + DTYPE_t ** dist_middle_terms_chunks + bint use_squared_distances + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesRadiusNeighborhood.is_usable_for(X, Y, metric) + and not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + DTYPE_t radius, + bint use_squared_distances=False, + chunk_size=None, + n_threads=None, + strategy=None, + sort_results=False, + metric_kwargs=None, + ): + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + radius=radius, + chunk_size=chunk_size, + n_threads=n_threads, + strategy=strategy, + sort_results=sort_results, + metric_kwargs=metric_kwargs, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + self.X, self.Y = datasets_pair.X, datasets_pair.Y + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + else: + self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms(self.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + if use_squared_distances: + # In this specialisation and this setup, the value passed to the radius is + # already considered to be the adapted radius, so we overwrite it. 
+ self.r_radius = radius + + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.effective_n_threads + ) + + def __dealloc__(self): + if self.dist_middle_terms_chunks is not NULL: + free(self.dist_middle_terms_chunks) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesRadiusNeighborhood.compute_exact_distances(self) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_init(self, thread_num) + + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_finalize(self, thread_num) + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _parallel_on_Y_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesRadiusNeighborhood._parallel_on_Y_init(self) + + for thread_num in range(self.chunks_n_threads): + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesRadiusNeighborhood._parallel_on_Y_finalize(self) + + for thread_num in range(self.chunks_n_threads): + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + DTYPE_t squared_dist_i_j + + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * A = &X_c[0, 0] + ITYPE_t lda = X_c.shape[1] + DTYPE_t * B = &Y_c[0, 0] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) + + # Pushing the distance and their associated indices in vectors. 
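The filtering done in the loop below compares squared distances against `r_radius`, the squared radius, so no square root is taken while scanning the chunk. In plain NumPy terms, with scipy's cdist only used for illustration:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.RandomState(0)
X, Y = rng.rand(4, 3), rng.rand(50, 3)
radius = 0.6
r_radius = radius ** 2          # cf. DistanceMetric._dist_to_rdist for "euclidean"

sq_dists = cdist(X, Y, metric="sqeuclidean")
neigh_indices = [np.flatnonzero(row <= r_radius) for row in sq_dists]
# Exact distances are only needed afterwards, for the retained neighbors.
neigh_distances = [np.sqrt(sq_dists[i, ind]) for i, ind in enumerate(neigh_indices)]

# Same neighborhoods as thresholding the actual euclidean distances.
dists = cdist(X, Y)
for i in range(X.shape[0]):
    np.testing.assert_array_equal(neigh_indices[i], np.flatnonzero(dists[i] <= radius))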
+ for i in range(X_c.shape[0]): + for j in range(Y_c.shape[0]): + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + squared_dist_i_j = ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * Y_c.shape[0] + j] + + self.Y_norm_squared[j + Y_start] + ) + if squared_dist_i_j <= self.r_radius: + deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) + deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index 1c26d9969397c..e03282af141f4 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -30,6 +30,7 @@ def configuration(parent_package="", top_path=None): "_pairwise_distances_reduction", sources=["_pairwise_distances_reduction.pyx"], include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")], + language="c++", libraries=libraries, ) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index e975aad55bf9c..8130e070be17d 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -7,6 +7,7 @@ from sklearn.metrics._pairwise_distances_reduction import ( PairwiseDistancesReduction, PairwiseDistancesArgKmin, + PairwiseDistancesRadiusNeighborhood, _sqeuclidean_row_norms, ) @@ -78,8 +79,25 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices): ) +def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, indices): + # We get arrays of arrays and we need to check for individual pairs + for i in range(ref_dist.shape[0]): + assert_array_equal( + ref_indices[i], + indices[i], + err_msg=f"Query vector #{i} has different neighbors' indices", + ) + assert_allclose( + ref_dist[i], + dist[i], + err_msg=f"Query vector #{i} has different neighbors' distances", + rtol=1e-7, + ) + + ASSERT_RESULT = { PairwiseDistancesArgKmin: assert_argkmin_results_equality, + PairwiseDistancesRadiusNeighborhood: assert_radius_neighborhood_results_equality, } @@ -148,12 +166,62 @@ def test_argkmin_factory_method_wrong_usages(): ) +def test_radius_neighborhood_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + radius = 5 + metric = "euclidean" + + with pytest.raises( + ValueError, + match=( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float32 and Y.dtype=float64" + ), + ): + PairwiseDistancesRadiusNeighborhood.compute( + X=X.astype(np.float32), Y=Y, radius=radius, metric=metric + ) + + with pytest.raises( + ValueError, + match=( + "Only 64bit float datasets are supported at this time, " + "got: X.dtype=float64 and Y.dtype=int32" + ), + ): + PairwiseDistancesRadiusNeighborhood.compute( + X=X, Y=Y.astype(np.int32), radius=radius, metric=metric + ) + + with pytest.raises(ValueError, match="radius == -1.0, must be >= 0."): + PairwiseDistancesRadiusNeighborhood.compute(X=X, Y=Y, radius=-1, metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistancesRadiusNeighborhood.compute( + X=X, Y=Y, radius=radius, metric="wrong metric" + ) + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistancesRadiusNeighborhood.compute( + X=np.array([1.0, 2.0]), Y=Y, radius=radius, metric=metric + ) + + with 
pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistancesRadiusNeighborhood.compute( + X=np.asfortranarray(X), Y=Y, radius=radius, metric=metric + ) + + @pytest.mark.parametrize("seed", range(5)) @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("chunk_size", [50, 512, 1024]) @pytest.mark.parametrize( "PairwiseDistancesReduction", - [PairwiseDistancesArgKmin], + [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) def test_chunk_size_agnosticism( PairwiseDistancesReduction, @@ -199,7 +267,7 @@ def test_chunk_size_agnosticism( @pytest.mark.parametrize("chunk_size", [50, 512, 1024]) @pytest.mark.parametrize( "PairwiseDistancesReduction", - [PairwiseDistancesArgKmin], + [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) def test_n_threads_agnosticism( PairwiseDistancesReduction, @@ -241,7 +309,7 @@ def test_n_threads_agnosticism( @pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics()) @pytest.mark.parametrize( "PairwiseDistancesReduction", - [PairwiseDistancesArgKmin], + [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) def test_strategies_consistency( PairwiseDistancesReduction, From 37b61e3531a56376bc9c699490511284362a336e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 28 Jan 2022 15:56:24 +0100 Subject: [PATCH 05/30] Plug PairwiseDistancesRadiusNeighborhood as a back-end Also move the error message upfront if results have to be sorted without the distances being returned. --- sklearn/neighbors/_base.py | 47 +++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index bcf448ae65c05..1120a63625420 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -22,6 +22,7 @@ from ..base import BaseEstimator, MultiOutputMixin from ..base import is_classifier from ..metrics import pairwise_distances_chunked +from ..metrics._pairwise_distances_reduction import PairwiseDistancesRadiusNeighborhood from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS from ..utils import ( check_array, @@ -1061,25 +1062,53 @@ class from an array representing our data set and ask who's """ check_is_fitted(self) - if X is not None: - query_is_train = False + if sort_results and not return_distance: + raise ValueError("return_distance must be True if sort_results is True.") + + query_is_train = X is None + if query_is_train: + X = self._fit_X + else: if self.metric == "precomputed": X = _check_precomputed(X) else: - X = self._validate_data(X, accept_sparse="csr", reset=False) - else: - query_is_train = True - X = self._fit_X + X = self._validate_data(X, accept_sparse="csr", reset=False, order="C") if radius is None: radius = self.radius - if self._fit_method == "brute" and self.metric == "precomputed" and issparse(X): + use_pairwise_distances_reductions = ( + self._fit_method == "brute" + and PairwiseDistancesRadiusNeighborhood.is_usable_for( + X if X is not None else self._fit_X, self._fit_X, self.effective_metric_ + ) + ) + + if use_pairwise_distances_reductions: + results = PairwiseDistancesRadiusNeighborhood.compute( + X=X, + Y=self._fit_X, + radius=radius, + metric=self.effective_metric_, + metric_kwargs=self.effective_metric_params_, + n_threads=self.n_jobs, + strategy="auto", + return_distance=return_distance, + sort_results=sort_results, + ) + + elif ( + self._fit_method == "brute" and self.metric == "precomputed" and issparse(X) + ): results = 
_radius_neighbors_from_graph( X, radius=radius, return_distance=return_distance ) elif self._fit_method == "brute": + # TODO: should no longer be needed once we have Cython-optimized + # implementation for radius queries, with support for sparse and/or + # float32 inputs. + # for efficiency, use squared euclidean distances if self.effective_metric_ == "euclidean": radius *= radius @@ -1113,10 +1142,6 @@ class from an array representing our data set and ask who's results = _to_object_array(neigh_ind_list) if sort_results: - if not return_distance: - raise ValueError( - "return_distance must be True if sort_results is True." - ) for ii in range(len(neigh_dist)): order = np.argsort(neigh_dist[ii], kind="mergesort") neigh_ind[ii] = neigh_ind[ii][order] From 390f62488edf386a7ec7d7916cc5d9e24d12b530 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 22 Feb 2022 11:05:45 +0100 Subject: [PATCH 06/30] TST Check consistency of backends' results --- sklearn/neighbors/tests/test_neighbors.py | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index f6e44b75f1ec2..32a6ade3e3436 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -25,6 +25,9 @@ from sklearn.exceptions import NotFittedError from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics.tests.test_dist_metrics import BOOL_METRICS +from sklearn.metrics.tests.test_pairwise_distances_reduction import ( + assert_radius_neighborhood_results_equality, +) from sklearn.model_selection import cross_val_score from sklearn.model_selection import train_test_split from sklearn.neighbors import VALID_METRICS_SPARSE @@ -2033,3 +2036,58 @@ def test_neighbors_distance_metric_deprecation(): dist_metric = DistanceMetric.get_metric("euclidean") assert isinstance(dist_metric, ActualDistanceMetric) + + +# TODO: Remove filterwarnings in 1.3 when wminkowski is removed +@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") +@pytest.mark.parametrize( + "metric", sorted(set(neighbors.VALID_METRICS["brute"]) - set(["precomputed"])) +) +def test_radius_neighbors_brute_backend( + metric, n_samples=2000, n_features=30, n_query_pts=100, n_neighbors=5 +): + # Both backend for the 'brute' algorithm of radius_neighbors + # must give identical results. 
+ X_train = rng.rand(n_samples, n_features) + X_test = rng.rand(n_query_pts, n_features) + + # Haversine distance only accepts 2D data + if metric == "haversine": + feature_sl = slice(None, 2) + X_train = np.ascontiguousarray(X_train[:, feature_sl]) + X_test = np.ascontiguousarray(X_test[:, feature_sl]) + + metric_params_list = _generate_test_params_for(metric, n_features) + + # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 + ExceptionToAssert = None + if metric == "wminkowski" and sp_version >= parse_version("1.6.0"): + ExceptionToAssert = FutureWarning + + for metric_params in metric_params_list: + p = metric_params.pop("p", 2) + + neigh = neighbors.NearestNeighbors( + n_neighbors=n_neighbors, + algorithm="brute", + metric=metric, + p=p, + metric_params=metric_params, + ) + + neigh.fit(X_train) + with pytest.warns(ExceptionToAssert): + with config_context(enable_cython_pairwise_dist=False): + # Use the legacy backend for brute + legacy_brute_dst, legacy_brute_idx = neigh.radius_neighbors( + X_test, return_distance=True + ) + with config_context(enable_cython_pairwise_dist=True): + # Use the PairwiseDistancesReduction as a backend for brute + pdr_brute_dst, pdr_brute_idx = neigh.radius_neighbors( + X_test, return_distance=True + ) + + assert_radius_neighborhood_results_equality( + legacy_brute_dst, pdr_brute_dst, legacy_brute_idx, pdr_brute_idx + ) From c4d8c4a4fafaf33be312cefb93b8bf0753ee6e91 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 22 Feb 2022 11:08:41 +0100 Subject: [PATCH 07/30] MAINT Group vector fixtures together --- .../metrics/_pairwise_distances_reduction.pyx | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 8f27906283fde..8f4e505f7441e 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -108,30 +108,6 @@ cdef class StdVectorSentinelITYPE(StdVectorSentinel): return sentinel -cpdef DTYPE_t[::1] _sqeuclidean_row_norms( - const DTYPE_t[:, ::1] X, - ITYPE_t num_threads, -): - """Compute the squared euclidean norm of the rows of X in parallel. - - This is faster than using np.einsum("ij, ij->i") even when using a single thread. - """ - cdef: - # Casting for X to remove the const qualifier is needed because APIs - # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' - # const qualifier. - # See: https://github.com/scipy/scipy/issues/14262 - DTYPE_t * X_ptr = &X[0, 0] - ITYPE_t idx = 0 - ITYPE_t n = X.shape[0] - ITYPE_t d = X.shape[1] - DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) - - for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): - squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) - - return squared_row_norms - cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): """Create a numpy ndarray given a C++ vector. @@ -176,6 +152,31 @@ cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( ##################### +cpdef DTYPE_t[::1] _sqeuclidean_row_norms( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. 
+ # See: https://github.com/scipy/scipy/issues/14262 + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t idx = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) + + for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): + squared_row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return squared_row_norms + +##################### cdef class PairwiseDistancesReduction: """Abstract base class for pairwise distance computation & reduction. From f091804c420e44b11caf907373bc73bc52ad6320 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 23 Feb 2022 11:58:32 +0100 Subject: [PATCH 08/30] DOC Update whats_new entry --- doc/whats_new/v1.1.rst | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index b0d36364ec333..ca5f72f5ad097 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -74,8 +74,8 @@ Changelog - |Efficiency| Low-level routines for reductions on pairwise distances for dense float64 datasets have been refactored. The following functions - and estimators now benefit from improved performances, in particular on - multi-cores machines: + and estimators now benefit from improved performances in terms of hardware + scalability and speed-ups: - :func:`sklearn.metrics.pairwise_distances_argmin` - :func:`sklearn.metrics.pairwise_distances_argmin_min` - :class:`sklearn.cluster.AffinityPropagation` @@ -86,6 +86,8 @@ Changelog - :func:`sklearn.feature_selection.mutual_info_regression` - :class:`sklearn.neighbors.KNeighborsClassifier` - :class:`sklearn.neighbors.KNeighborsRegressor` + - :class:`sklearn.neighbors.RadiusNeighborsClassifier` + - :class:`sklearn.neighbors.RadiusNeighborsRegressor` - :class:`sklearn.neighbors.LocalOutlierFactor` - :class:`sklearn.neighbors.NearestNeighbors` - :class:`sklearn.manifold.Isomap` @@ -95,10 +97,11 @@ Changelog - :class:`sklearn.semi_supervised.LabelPropagation` - :class:`sklearn.semi_supervised.LabelSpreading` - For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors` + For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors` and + :class:`sklearn.neighbors.NearestNeighbors.kneighbors` can be up to ×20 faster than in the previous versions'. 
- :pr:`21987`, :pr:`22064`, :pr:`22065` and :pr:`22288` + :pr:`21987`, :pr:`22064`, :pr:`22065`, :pr:`22288` and :pr:`22320` by :user:`Julien Jerphanion ` - |Enhancement| All scikit-learn models now generate a more informative From e477c10379c0cd74080dc1cb38cfe9abc0f61000 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 24 Feb 2022 19:25:07 +0100 Subject: [PATCH 09/30] Trigger full [cd build] From 494a0b54d001633f81705fbe86317b1ee8d4b83c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:02:55 +0100 Subject: [PATCH 10/30] Reformat and fix typos Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 36 +++++++++---------- sklearn/neighbors/tests/test_neighbors.py | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 8f4e505f7441e..a8459006ae292 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1243,7 +1243,7 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # # Though it is possible to access their buffer address with # std::vector::data, they can't be stolen: buffers lifetime - # is tight to their std::vector and are deallocated when + # is tied to their std::vector and are deallocated when # std::vectors are. # # To solve this, we dynamically allocate std::vectors and then @@ -1344,29 +1344,29 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): Returns ------- - If return_distance=False: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. + If return_distance=False: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. - If return_distance=True: - - neighbors_indices : ndarray of n_samples_X ndarray - Indices of the neighbors for each vector in X. - - neighbors_distances : ndarray of n_samples_X ndarray - Distances to the neighbors for each vector in X. + If return_distance=True: + - neighbors_indices : ndarray of n_samples_X ndarray + Indices of the neighbors for each vector in X. + - neighbors_distances : ndarray of n_samples_X ndarray + Distances to the neighbors for each vector in X. Notes ----- - This public classmethod is responsible for introspecting the arguments - values to dispatch to the private - :meth:`PairwiseDistancesRadiusNeighborhood._compute` instance method of - the most appropriate :class:`PairwiseDistancesRadiusNeighborhood` - concrete implementation. + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private + :meth:`PairwiseDistancesRadiusNeighborhood._compute` instance method of + the most appropriate :class:`PairwiseDistancesRadiusNeighborhood` + concrete implementation. - All temporarily allocated datastructures necessary for the concrete - implementation are therefore freed when this classmethod returns. + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. - This allows entirely decoupling the interface entirely from the - implementation details whilst maintaining RAII. + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. """ # Note (jjerphan): Some design thoughts for future extensions. # This factory comes to handle specialisations for the given arguments. 
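As an aside, the dispatching entry point documented in the hunk above is exercised by the tests added later in this series roughly as follows. This is only a usage sketch against the private `sklearn.metrics._pairwise_distances_reduction` submodule (not public API); the toy data, the radius value and the choice of the Euclidean metric are illustrative:

    import numpy as np

    # Private submodule introduced by this series; not part of the public API.
    from sklearn.metrics._pairwise_distances_reduction import (
        PairwiseDistancesRadiusNeighborhood,
    )

    rng = np.random.RandomState(0)
    X = rng.rand(100, 10)  # query vectors
    Y = rng.rand(200, 10)  # indexed vectors

    # Distances to, and indices of, all rows of Y lying within `radius` of each
    # row of X, each neighborhood sorted by increasing distance. The classmethod
    # introspects the arguments and dispatches to the most appropriate concrete
    # implementation.
    neigh_distances, neigh_indices = PairwiseDistancesRadiusNeighborhood.compute(
        X,
        Y,
        radius=1.0,
        metric="euclidean",
        strategy="auto",
        return_distance=True,
        sort_results=True,
    )

Each returned object is an array of `n_samples_X` variable-length ndarrays, matching the ragged structure described in the Returns section above.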
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 963682cbc2156..5ca2ced6e4fdc 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2046,7 +2046,7 @@ def test_neighbors_distance_metric_deprecation(): def test_radius_neighbors_brute_backend( metric, n_samples=2000, n_features=30, n_query_pts=100, n_neighbors=5 ): - # Both backend for the 'brute' algorithm of radius_neighbors + # Both backends for the 'brute' algorithm of radius_neighbors # must give identical results. X_train = rng.rand(n_samples, n_features) X_test = rng.rand(n_query_pts, n_features) From 72945278f4067c2ac5e0f23db7c32dd98ac9ec5b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:06:38 +0100 Subject: [PATCH 11/30] MAINT Use proper naming --- .../metrics/tests/test_pairwise_distances_reduction.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 31970155d1a71..f864675e81d79 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -30,7 +30,7 @@ ] -def _get_dummy_metric_params_list(metric: str, n_features: int): +def _get_metric_params_list(metric: str, n_features: int): """Return list of dummy DistanceMetric kwargs for tests.""" # Distinguishing on cases not to compute unneeded datastructures. @@ -345,7 +345,7 @@ def test_strategies_consistency( parameter, metric=metric, # Taking the first - metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + metric_kwargs=_get_metric_params_list(metric, n_features)[0], # To be sure to use parallelization chunk_size=n_samples // 4, strategy="parallel_on_X", @@ -358,7 +358,7 @@ def test_strategies_consistency( parameter, metric=metric, # Taking the first - metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + metric_kwargs=_get_metric_params_list(metric, n_features)[0], # To be sure to use parallelization chunk_size=n_samples // 4, strategy="parallel_on_Y", @@ -400,7 +400,7 @@ def test_pairwise_distances_argkmin( X = np.ascontiguousarray(X[:, :2]) Y = np.ascontiguousarray(Y[:, :2]) - metric_kwargs = _get_dummy_metric_params_list(metric, n_features)[0] + metric_kwargs = _get_metric_params_list(metric, n_features)[0] # Reference for argkmin results if metric == "euclidean": From 2ffee2b9ad9bd9f105e652a1a309453a6d7dca43 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:26:37 +0100 Subject: [PATCH 12/30] TST Test against a simple high-level implementation" Co-authored-by: Olivier Grisel --- .../test_pairwise_distances_reduction.py | 71 ++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index f864675e81d79..d2db5dfefe0f3 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -373,7 +373,7 @@ def test_strategies_consistency( ) -# Concrete PairwiseDistancesReductions tests +# "Concrete PairwiseDistancesReductions"-specific tests # TODO: Remove filterwarnings in 1.3 when wminkowski is removed @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") @@ -434,6 +434,75 @@ def test_pairwise_distances_argkmin( ) +# TODO: Remove 
filterwarnings in 1.3 when wminkowski is removed +@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn") +@pytest.mark.parametrize("n_features", [50, 500]) +@pytest.mark.parametrize("translation", [0, 1e6]) +@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) +@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) +def test_pairwise_distances_radius_neighbors( + n_features, + translation, + metric, + strategy, + n_samples=100, + dtype=np.float64, +): + rng = np.random.RandomState(0) + spread = 1000 + radius = spread * np.log(n_features) + X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = np.ascontiguousarray(Y[:, :2]) + + metric_kwargs = _get_metric_params_list(metric, n_features)[0] + + # Reference for argkmin results + if metric == "euclidean": + # Compare to scikit-learn GEMM optimized implementation + dist_matrix = euclidean_distances(X, Y) + else: + dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs) + + # Getting the neighbors for a given radius + neigh_indices_ref = [] + neigh_distances_ref = [] + + for row in dist_matrix: + ind = np.arange(row.shape[0])[row <= radius] + dist = row[ind] + + sort = np.argsort(dist) + ind, dist = ind[sort], dist[sort] + + neigh_indices_ref.append(ind) + neigh_distances_ref.append(dist) + + neigh_indices_ref = np.array(neigh_indices_ref) + neigh_distances_ref = np.array(neigh_distances_ref) + + neigh_distances, neigh_indices = PairwiseDistancesRadiusNeighborhood.compute( + X, + Y, + radius, + metric=metric, + metric_kwargs=metric_kwargs, + return_distance=True, + # So as to have more than a chunk, forcing parallelism. + chunk_size=n_samples // 4, + strategy=strategy, + sort_results=True, + ) + + ASSERT_RESULT[PairwiseDistancesRadiusNeighborhood]( + neigh_distances, neigh_distances_ref, neigh_indices, neigh_indices_ref + ) + + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) From 7c5e6e9e0b644a76e05815af7a90bae50adbb16a Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:29:15 +0100 Subject: [PATCH 13/30] DOC Fix whats_new entry --- doc/whats_new/v1.1.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index ca5f72f5ad097..51ff641d920f1 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -98,8 +98,8 @@ Changelog - :class:`sklearn.semi_supervised.LabelSpreading` For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors` and - :class:`sklearn.neighbors.NearestNeighbors.kneighbors` - can be up to ×20 faster than in the previous versions'. + :class:`sklearn.neighbors.NearestNeighbors.radius_neighbors` + can respectively be up to ×20 and ×5 faster than previously. :pr:`21987`, :pr:`22064`, :pr:`22065`, :pr:`22288` and :pr:`22320` by :user:`Julien Jerphanion ` From adda9defd3f491e820cf51a8e237314ce554b42b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:31:07 +0100 Subject: [PATCH 14/30] MAINT Remove n_threads Catching up with #22593. 
Co-authored-by: Olivier Grisel --- .../metrics/_pairwise_distances_reduction.pyx | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index cdeeda81ab3a0..0d33d9a49a3d1 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1235,7 +1235,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): str metric="euclidean", chunk_size=None, dict metric_kwargs=None, - n_threads=None, str strategy=None, bint return_distance=False, bint sort_results=False, @@ -1266,15 +1265,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): metric_kwargs : dict, default=None Keyword arguments to pass to specified metric function. - n_threads : int, default=None - The number of OpenMP threads to use for the reduction. - Parallelism is done on chunks and the sharding of chunks - depends on the `strategy` set on - :meth:`~PairwiseDistancesRadiusNeighborhood.compute`. - - See _openmp_effective_n_threads, for details about - the specification of n_threads. - strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None The chunking strategy defining which dataset parallelization are made on. @@ -1356,7 +1346,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): use_squared_distances=use_squared_distances, chunk_size=chunk_size, metric_kwargs=metric_kwargs, - n_threads=n_threads, strategy=strategy, sort_results=sort_results, ) @@ -1368,7 +1357,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): radius=radius, chunk_size=chunk_size, metric_kwargs=metric_kwargs, - n_threads=n_threads, strategy=strategy, sort_results=sort_results, ) @@ -1389,7 +1377,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): DatasetsPair datasets_pair, DTYPE_t radius, chunk_size=None, - n_threads=None, strategy=None, sort_results=False, metric_kwargs=None, @@ -1397,7 +1384,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): super().__init__( datasets_pair=datasets_pair, chunk_size=chunk_size, - n_threads=n_threads, strategy=strategy, ) @@ -1635,7 +1621,6 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad DTYPE_t radius, bint use_squared_distances=False, chunk_size=None, - n_threads=None, strategy=None, sort_results=False, metric_kwargs=None, @@ -1645,7 +1630,6 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), radius=radius, chunk_size=chunk_size, - n_threads=n_threads, strategy=strategy, sort_results=sort_results, metric_kwargs=metric_kwargs, From a823a05967031c710c8273227ec421ddcfe7e578 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 24 Feb 2022 22:36:02 +0100 Subject: [PATCH 15/30] Trigger full [cd build] From 15f38640d39a2bcb5ef1b09cf65045378a7933f0 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 25 Feb 2022 09:27:52 +0100 Subject: [PATCH 16/30] fixup! 
MAINT Remove n_threads --- sklearn/neighbors/_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 2e1de1b8e5042..10e99a34e6497 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -1100,7 +1100,6 @@ class from an array representing our data set and ask who's radius=radius, metric=self.effective_metric_, metric_kwargs=self.effective_metric_params_, - n_threads=self.n_jobs, strategy="auto", return_distance=return_distance, sort_results=sort_results, From 49a9c875a22be6cd375f1af07ee652bd4a247f38 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 25 Feb 2022 09:28:04 +0100 Subject: [PATCH 17/30] Trigger full [cd build] From 2d946092e1580572731608d003b5cc54b317caba Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 25 Feb 2022 10:52:10 +0100 Subject: [PATCH 18/30] LINT Remove extra spaces --- sklearn/neighbors/tests/test_neighbors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index 5ca2ced6e4fdc..e7ee8a507838e 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -2092,7 +2092,7 @@ def test_radius_neighbors_brute_backend( legacy_brute_dst, pdr_brute_dst, legacy_brute_idx, pdr_brute_idx ) - + def test_valid_metrics_has_no_duplicate(): for val in neighbors.VALID_METRICS.values(): assert len(val) == len(set(val)) From c5fb6d4a9afabc4de1a268a4598127218179a147 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 25 Feb 2022 10:52:17 +0100 Subject: [PATCH 19/30] Trigger full [cd build] From 30f6826d3f7044b29997899b10930d5e42148abc Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 26 Feb 2022 20:47:45 +0100 Subject: [PATCH 20/30] DOC Remove n_threads from docstring Co-authored-by: Thomas J. Fan --- sklearn/metrics/_pairwise_distances_reduction.pyx | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 0d33d9a49a3d1..8b28ea3beec66 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1181,15 +1181,6 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): scikit-learn configuration for `pairwise_dist_chunk_size`, and use 256 if it is not set. - n_threads: int, default=None - The number of OpenMP threads to use for the reduction. - Parallelism is done on chunks and the sharding of chunks - depends on the `strategy` set on - :meth:`~PairwiseDistancesRadiusNeighborhood.compute`. - - See _openmp_effective_n_threads, for details about - the specification of n_threads. - radius: float The radius defining the neighborhood. """ From 05adeb0f1f2ef78b1573c2fb2c52c2762c649b31 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 26 Feb 2022 21:18:56 +0100 Subject: [PATCH 21/30] Use shared pointers instead of raw pointers Co-authored-by: Thomas J. Fan --- .../metrics/_pairwise_distances_reduction.pyx | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 8b28ea3beec66..102748249c8b7 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -20,6 +20,7 @@ import warnings from .. 
import get_config from libc.stdlib cimport free, malloc from libc.float cimport DBL_MAX +from libcpp.memory cimport shared_ptr, make_shared from libcpp.vector cimport vector from cython cimport final from cython.operator cimport dereference as deref @@ -137,7 +138,7 @@ cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( - vector_vector_DITYPE_t* vecs + shared_ptr[vector_vector_DITYPE_t] vecs ): """Coerce a std::vector of std::vector to a ndarray of ndarray.""" cdef: @@ -1208,12 +1209,16 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # To solve this, we dynamically allocate std::vectors and then # encapsulate them in a StdVectorSentinel responsible for # freeing them when the associated np.ndarray is freed. - vector[vector[ITYPE_t]] * neigh_indices - vector[vector[DTYPE_t]] * neigh_distances + # + # Shared pointers (defined via shared_ptr) are use for safer memory management. + # Unique pointers (defined via unique_ptr) can't be used as datastructures + # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. + shared_ptr[vector[vector[ITYPE_t]]] neigh_indices + shared_ptr[vector[vector[DTYPE_t]]] neigh_distances # Used as array of pointers to private datastructures used in threads. - vector[vector[ITYPE_t]] ** neigh_indices_chunks - vector[vector[DTYPE_t]] ** neigh_distances_chunks + vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks + vector[shared_ptr[vector[vector[DTYPE_t]]]] neigh_distances_chunks bint sort_results @@ -1391,30 +1396,17 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # - when parallelizing on Y, the pointers of those heaps are referencing # std::vectors of std::vectors which are thread-wise-allocated and whose # content will be merged into self.neigh_distances and self.neigh_indices. - self.neigh_distances_chunks = malloc( - sizeof(self.neigh_distances) * self.chunks_n_threads + self.neigh_distances_chunks = vector[shared_ptr[vector[vector[DTYPE_t]]]]( + self.chunks_n_threads ) - self.neigh_indices_chunks = malloc( - sizeof(self.neigh_indices) * self.chunks_n_threads + self.neigh_indices_chunks = vector[shared_ptr[vector[vector[ITYPE_t]]]]( + self.chunks_n_threads ) # Temporary datastructures which will be coerced to numpy arrays on before # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. - self.neigh_distances = new vector[vector[DTYPE_t]](self.n_samples_X) - self.neigh_indices = new vector[vector[ITYPE_t]](self.n_samples_X) - - def __dealloc__(self): - if self.neigh_distances_chunks is not NULL: - free(self.neigh_distances_chunks) - - if self.neigh_indices_chunks is not NULL: - free(self.neigh_indices_chunks) - - if self.neigh_indices is not NULL: - del self.neigh_indices - - if self.neigh_distances is not NULL: - del self.neigh_distances + self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) cdef void _compute_and_reduce_distances_on_chunks( self, @@ -1487,8 +1479,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # conditions: each thread has its own vectors of n_samples_X vectors which are # then merged back in the main n_samples_X vectors. 
for thread_num in range(self.chunks_n_threads): - self.neigh_distances_chunks[thread_num] = new vector[vector[DTYPE_t]](self.n_samples_X) - self.neigh_indices_chunks[thread_num] = new vector[vector[ITYPE_t]](self.n_samples_X) + self.neigh_distances_chunks[thread_num] = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices_chunks[thread_num] = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) @final cdef void _merge_vectors( @@ -1541,9 +1533,9 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # The content of the vector have been std::moved. # Hence they can't be used anymore and can only be deleted. - for thread_num in prange(self.chunks_n_threads, schedule='static'): - del self.neigh_distances_chunks[thread_num] - del self.neigh_indices_chunks[thread_num] + # for thread_num in prange(self.chunks_n_threads, schedule='static'): + # del self.neigh_distances_chunks[thread_num] + # del self.neigh_indices_chunks[thread_num] # Sort in parallel in ascending order w.r.t the distances if requested. if self.sort_results: From 086fb78fc4123ae5db7f279d6ccc900f6058a39f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sat, 26 Feb 2022 21:23:54 +0100 Subject: [PATCH 22/30] fixup! Use shared pointers instead of raw pointers --- sklearn/metrics/_pairwise_distances_reduction.pyx | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 102748249c8b7..73c5412d75ad3 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1532,10 +1532,9 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): self._merge_vectors(idx, self.chunks_n_threads) # The content of the vector have been std::moved. - # Hence they can't be used anymore and can only be deleted. - # for thread_num in prange(self.chunks_n_threads, schedule='static'): - # del self.neigh_distances_chunks[thread_num] - # del self.neigh_indices_chunks[thread_num] + # Hence they can't be used anymore and can be deleted. + # Their deletion is carried out automatically as the + # implementation relies on shared pointers. # Sort in parallel in ascending order w.r.t the distances if requested. if self.sort_results: From 708e7341c2527d6f772ff12fec82c9b32a6270d1 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 26 Feb 2022 23:49:23 -0500 Subject: [PATCH 23/30] BLD Add compiler args for C++11 --- sklearn/metrics/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index e03282af141f4..736ba6d7d4424 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -32,6 +32,7 @@ def configuration(parent_package="", top_path=None): include_dirs=[np.get_include(), os.path.join(np.get_include(), "numpy")], language="c++", libraries=libraries, + extra_compile_args=["-std=c++11"], ) config.add_subpackage("tests") From 84fb3214e9e38499392f4863d9c1df30e97f4b8a Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Sun, 27 Feb 2022 00:03:25 -0500 Subject: [PATCH 24/30] ENH Uses unique pointers --- .../metrics/_pairwise_distances_reduction.pyx | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 73c5412d75ad3..efeb917505b12 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -20,7 +20,7 @@ import warnings from .. import get_config from libc.stdlib cimport free, malloc from libc.float cimport DBL_MAX -from libcpp.memory cimport shared_ptr, make_shared +from libcpp.memory cimport shared_ptr, make_shared, unique_ptr, make_unique from libcpp.vector cimport vector from cython cimport final from cython.operator cimport dereference as deref @@ -138,7 +138,7 @@ cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( - shared_ptr[vector_vector_DITYPE_t] vecs + unique_ptr[vector_vector_DITYPE_t] vecs ): """Coerce a std::vector of std::vector to a ndarray of ndarray.""" cdef: @@ -1213,8 +1213,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # Shared pointers (defined via shared_ptr) are use for safer memory management. # Unique pointers (defined via unique_ptr) can't be used as datastructures # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. - shared_ptr[vector[vector[ITYPE_t]]] neigh_indices - shared_ptr[vector[vector[DTYPE_t]]] neigh_distances + unique_ptr[vector[vector[ITYPE_t]]] neigh_indices + unique_ptr[vector[vector[DTYPE_t]]] neigh_distances # Used as array of pointers to private datastructures used in threads. vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks @@ -1405,8 +1405,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # Temporary datastructures which will be coerced to numpy arrays on before # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. - self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) + self.neigh_distances = make_unique[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices = make_unique[vector[vector[ITYPE_t]]](self.n_samples_X) cdef void _compute_and_reduce_distances_on_chunks( self, @@ -1448,8 +1448,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # As this strategy is embarrassingly parallel, we can set the # thread vectors' pointers to the main vectors'. 
- self.neigh_distances_chunks[thread_num] = self.neigh_distances - self.neigh_indices_chunks[thread_num] = self.neigh_indices + self.neigh_distances_chunks[thread_num] = shared_ptr[vector[vector[DTYPE_t]]](self.neigh_distances) + self.neigh_indices_chunks[thread_num] = shared_ptr[vector[vector[ITYPE_t]]](self.neigh_indices) @final cdef void _parallel_on_X_prange_iter_finalize( From 970d4ddbc428f23f4de0ef25a29a58eb0f0bb68b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sun, 27 Feb 2022 08:26:11 +0100 Subject: [PATCH 25/30] DOC Update comment for unique pointers --- sklearn/metrics/_pairwise_distances_reduction.pyx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index efeb917505b12..2909f30530469 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1211,8 +1211,9 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # freeing them when the associated np.ndarray is freed. # # Shared pointers (defined via shared_ptr) are use for safer memory management. - # Unique pointers (defined via unique_ptr) can't be used as datastructures - # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. + # Unique pointers (defined via unique_ptr) are also used for the main + # datastructures and are cast as shared pointers across threads for + # parallel_on_X; see _parallel_on_X_init_chunk. unique_ptr[vector[vector[ITYPE_t]]] neigh_indices unique_ptr[vector[vector[DTYPE_t]]] neigh_distances From 6e617735f40f8dfadb3eaf1f51275627cf2b0d40 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Sun, 27 Feb 2022 08:52:07 +0100 Subject: [PATCH 26/30] Revert "ENH Uses unique pointers" --- .../metrics/_pairwise_distances_reduction.pyx | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 2909f30530469..73c5412d75ad3 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -20,7 +20,7 @@ import warnings from .. import get_config from libc.stdlib cimport free, malloc from libc.float cimport DBL_MAX -from libcpp.memory cimport shared_ptr, make_shared, unique_ptr, make_unique +from libcpp.memory cimport shared_ptr, make_shared from libcpp.vector cimport vector from cython cimport final from cython.operator cimport dereference as deref @@ -138,7 +138,7 @@ cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): cdef np.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( - unique_ptr[vector_vector_DITYPE_t] vecs + shared_ptr[vector_vector_DITYPE_t] vecs ): """Coerce a std::vector of std::vector to a ndarray of ndarray.""" cdef: @@ -1211,11 +1211,10 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # freeing them when the associated np.ndarray is freed. # # Shared pointers (defined via shared_ptr) are use for safer memory management. - # Unique pointers (defined via unique_ptr) are also used for the main - # datastructures and are cast as shared pointers across threads for - # parallel_on_X; see _parallel_on_X_init_chunk. 
- unique_ptr[vector[vector[ITYPE_t]]] neigh_indices - unique_ptr[vector[vector[DTYPE_t]]] neigh_distances + # Unique pointers (defined via unique_ptr) can't be used as datastructures + # are shared across threads for parallel_on_X; see _parallel_on_X_init_chunk. + shared_ptr[vector[vector[ITYPE_t]]] neigh_indices + shared_ptr[vector[vector[DTYPE_t]]] neigh_distances # Used as array of pointers to private datastructures used in threads. vector[shared_ptr[vector[vector[ITYPE_t]]]] neigh_indices_chunks @@ -1406,8 +1405,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # Temporary datastructures which will be coerced to numpy arrays on before # PairwiseDistancesRadiusNeighborhood.compute "return" and will be then freed. - self.neigh_distances = make_unique[vector[vector[DTYPE_t]]](self.n_samples_X) - self.neigh_indices = make_unique[vector[vector[ITYPE_t]]](self.n_samples_X) + self.neigh_distances = make_shared[vector[vector[DTYPE_t]]](self.n_samples_X) + self.neigh_indices = make_shared[vector[vector[ITYPE_t]]](self.n_samples_X) cdef void _compute_and_reduce_distances_on_chunks( self, @@ -1449,8 +1448,8 @@ cdef class PairwiseDistancesRadiusNeighborhood(PairwiseDistancesReduction): # As this strategy is embarrassingly parallel, we can set the # thread vectors' pointers to the main vectors'. - self.neigh_distances_chunks[thread_num] = shared_ptr[vector[vector[DTYPE_t]]](self.neigh_distances) - self.neigh_indices_chunks[thread_num] = shared_ptr[vector[vector[ITYPE_t]]](self.neigh_indices) + self.neigh_distances_chunks[thread_num] = self.neigh_distances + self.neigh_indices_chunks[thread_num] = self.neigh_indices @final cdef void _parallel_on_X_prange_iter_finalize( From 46f3bc442a1c816c658278ef51da2ed6791b633a Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 3 Mar 2022 17:42:08 -0500 Subject: [PATCH 27/30] ENH Uses vectors for Euclidean Radius --- .../metrics/_pairwise_distances_reduction.pyx | 38 ++++--------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 73c5412d75ad3..860a467836bdc 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -1588,7 +1588,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad const DTYPE_t[::1] Y_norm_squared # Buffers for GEMM - DTYPE_t ** dist_middle_terms_chunks + vector[vector[DTYPE_t]] dist_middle_terms_chunks bint use_squared_distances @classmethod @@ -1639,14 +1639,10 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad self.r_radius = radius # Temporary datastructures used in threads - self.dist_middle_terms_chunks = malloc( - sizeof(DTYPE_t *) * self.effective_n_threads + self.dist_middle_terms_chunks = vector[vector[DTYPE_t]]( + self.effective_n_threads ) - def __dealloc__(self): - if self.dist_middle_terms_chunks is not NULL: - free(self.dist_middle_terms_chunks) - @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: @@ -1660,18 +1656,10 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_init(self, thread_num) # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num] = malloc( - self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + self.dist_middle_terms_chunks[thread_num].resize( + self.Y_n_samples_chunk * self.X_n_samples_chunk ) - @final - cdef void _parallel_on_X_parallel_finalize( - self, - ITYPE_t thread_num - ) nogil: - PairwiseDistancesRadiusNeighborhood._parallel_on_X_parallel_finalize(self, thread_num) - free(self.dist_middle_terms_chunks[thread_num]) - @final cdef void _parallel_on_Y_init( self, @@ -1681,20 +1669,10 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad for thread_num in range(self.chunks_n_threads): # Temporary buffer for the `-2 * X_c @ Y_c.T` term - self.dist_middle_terms_chunks[thread_num] = malloc( - self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + self.dist_middle_terms_chunks[thread_num].resize( + self.Y_n_samples_chunk * self.X_n_samples_chunk ) - @final - cdef void _parallel_on_Y_finalize( - self, - ) nogil: - cdef ITYPE_t thread_num - PairwiseDistancesRadiusNeighborhood._parallel_on_Y_finalize(self) - - for thread_num in range(self.chunks_n_threads): - free(self.dist_middle_terms_chunks[thread_num]) - @final cdef void _compute_and_reduce_distances_on_chunks( self, @@ -1710,7 +1688,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] - DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() # Careful: LDA, LDB and LDC are given for F-ordered arrays # in BLAS documentations, for instance: From 0a2736b6915da82752b1c2d2edbc42c05fae3724 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 4 Mar 2022 11:37:15 +0100 Subject: [PATCH 28/30] DOC Reference external documentations for StdVectorSentinel Co-authored-by: Thomas J. 
Fan --- sklearn/metrics/_pairwise_distances_reduction.pyx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx index 860a467836bdc..29ac839187fc9 100644 --- a/sklearn/metrics/_pairwise_distances_reduction.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -93,6 +93,7 @@ cdef class StdVectorSentinelDTYPE(StdVectorSentinel): @staticmethod cdef StdVectorSentinel create_for(vector[DTYPE_t] * vec_ptr): # This initializes the object directly without calling __init__ + # See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa cdef StdVectorSentinelDTYPE sentinel = StdVectorSentinelDTYPE.__new__(StdVectorSentinelDTYPE) sentinel.vec.swap(deref(vec_ptr)) return sentinel @@ -104,6 +105,7 @@ cdef class StdVectorSentinelITYPE(StdVectorSentinel): @staticmethod cdef StdVectorSentinel create_for(vector[ITYPE_t] * vec_ptr): # This initializes the object directly without calling __init__ + # See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa cdef StdVectorSentinelITYPE sentinel = StdVectorSentinelITYPE.__new__(StdVectorSentinelITYPE) sentinel.vec.swap(deref(vec_ptr)) return sentinel @@ -129,8 +131,8 @@ cdef np.ndarray vector_to_nd_array(vector_DITYPE_t * vect_ptr): sentinel = StdVectorSentinelITYPE.create_for(vect_ptr) # Makes the numpy array responsible of the life-cycle of its buffer. - # A reference to the StdVectorSentinel will be stolen by the call bellow, - # so we increase its reference counter. + # A reference to the StdVectorSentinel will be stolen by the call to + # `PyArray_SetBaseObject` below, so we increase its reference counter. # See: https://docs.python.org/3/c-api/intro.html#reference-count-details Py_INCREF(sentinel) np.PyArray_SetBaseObject(arr, sentinel) From bba8f73e0ba52aa8195d02562f9fda8629968337 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 4 Mar 2022 14:15:11 +0100 Subject: [PATCH 29/30] TST Remove useless haversine case --- sklearn/metrics/tests/test_pairwise_distances_reduction.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index fb55e0f6dd613..308eece1fb6df 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -456,11 +456,6 @@ def test_pairwise_distances_radius_neighbors( X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread - # Haversine distance only accepts 2D data - if metric == "haversine": - X = np.ascontiguousarray(X[:, :2]) - Y = np.ascontiguousarray(Y[:, :2]) - metric_kwargs = _get_metric_params_list(metric, n_features)[0] # Reference for argkmin results From afa4e55befba413e3c5c19fb934f5897ff4c705a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 4 Mar 2022 11:47:22 -0500 Subject: [PATCH 30/30] CI [cd build]
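As a closing note on semantics: the behaviour that `PairwiseDistancesRadiusNeighborhood` has to reproduce is the one pinned down by the high-level reference introduced in patch 12, which boils down to the following NumPy/SciPy sketch (the helper name `radius_neighbors_reference` is illustrative and not part of scikit-learn):

    import numpy as np
    from scipy.spatial.distance import cdist


    def radius_neighbors_reference(X, Y, radius, metric="euclidean", **metric_kwargs):
        """Brute-force reference: rows of Y within `radius` of each row of X."""
        dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs)

        neigh_indices, neigh_distances = [], []
        for row in dist_matrix:
            ind = np.flatnonzero(row <= radius)
            dist = row[ind]
            order = np.argsort(dist)  # sort each neighborhood by increasing distance
            neigh_indices.append(ind[order])
            neigh_distances.append(dist[order])

        # Ragged results: one variable-length array per query vector.
        return neigh_distances, neigh_indices

For the Euclidean case, the tests instead build the reference distance matrix with scikit-learn's GEMM-optimized `euclidean_distances`, which is precisely the code path that the `FastEuclideanPairwiseDistancesRadiusNeighborhood` specialisation accelerates with its per-thread GEMM buffers.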