diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 8463954782066..ded8c151a2915 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -65,6 +65,34 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. +- |Efficiency| Low-level routines for reductions on pairwise distances + for dense float64 datasets have been refactored. The following functions + and estimators now benefit from improved performances, in particular on + multi-cores machines: + - :func:`sklearn.metrics.pairwise_distances_argmin` + - :func:`sklearn.metrics.pairwise_distances_argmin_min` + - :class:`sklearn.cluster.AffinityPropagation` + - :class:`sklearn.cluster.Birch` + - :class:`sklearn.cluster.MeanShift` + - :class:`sklearn.cluster.OPTICS` + - :class:`sklearn.cluster.SpectralClustering` + - :func:`sklearn.feature_selection.mutual_info_regression` + - :class:`sklearn.neighbors.KNeighborsClassifier` + - :class:`sklearn.neighbors.KNeighborsRegressor` + - :class:`sklearn.neighbors.LocalOutlierFactor` + - :class:`sklearn.neighbors.NearestNeighbors` + - :class:`sklearn.manifold.Isomap` + - :class:`sklearn.manifold.LocallyLinearEmbedding` + - :class:`sklearn.manifold.TSNE` + - :func:`sklearn.manifold.trustworthiness` + - :class:`sklearn.semi_supervised.LabelPropagation` + - :class:`sklearn.semi_supervised.LabelSpreading` + + For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors` + can be up to 20× faster than in the previous versions'. + + :pr:`21462` by :user:`Julien Jerphanion `. + - |Enhancement| All scikit-learn models now generate a more informative error message when some input contains unexpected `NaN` or infinite values. In particular the message contains the input name ("X", "y" or diff --git a/sklearn/_config.py b/sklearn/_config.py index c41c180012056..d6a02737f640d 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -9,6 +9,9 @@ "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": int( + os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256) + ), } _threadlocal = threading.local() @@ -40,7 +43,11 @@ def get_config(): def set_config( - assume_finite=None, working_memory=None, print_changed_only=None, display=None + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + pairwise_dist_chunk_size=None, ): """Set global scikit-learn configuration @@ -80,6 +87,12 @@ def set_config( .. versionadded:: 0.23 + pairwise_dist_chunk_size : int, default=None + The number of vectors per chunk for PairwiseDistancesReduction. + Default is 256 (suitable for most of modern laptops' caches and architectures). + + .. versionadded:: 1.1 + See Also -------- config_context : Context manager for global scikit-learn configuration. @@ -95,11 +108,18 @@ def set_config( local_config["print_changed_only"] = print_changed_only if display is not None: local_config["display"] = display + if pairwise_dist_chunk_size is not None: + local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size @contextmanager def config_context( - *, assume_finite=None, working_memory=None, print_changed_only=None, display=None + *, + assume_finite=None, + working_memory=None, + print_changed_only=None, + display=None, + pairwise_dist_chunk_size=None, ): """Context manager for global scikit-learn configuration. @@ -138,6 +158,12 @@ def config_context( .. 
versionadded:: 0.23 + pairwise_dist_chunk_size : int, default=None + The number of vectors per chunk for PairwiseDistancesReduction. + Default is 256 (suitable for most of modern laptops' caches and architectures). + + .. versionadded:: 1.1 + Yields ------ None. diff --git a/sklearn/metrics/_dist_metrics.pxd b/sklearn/metrics/_dist_metrics.pxd index 611f6759e2c8b..e7c2f2ea2f926 100644 --- a/sklearn/metrics/_dist_metrics.pxd +++ b/sklearn/metrics/_dist_metrics.pxd @@ -64,3 +64,24 @@ cdef class DistanceMetric: cdef DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1 cdef DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1 + + +###################################################################### +# DatasetsPair base class +cdef class DatasetsPair: + cdef DistanceMetric distance_metric + + cdef ITYPE_t n_samples_X(self) nogil + + cdef ITYPE_t n_samples_Y(self) nogil + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil + + +cdef class DenseDenseDatasetsPair(DatasetsPair): + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + ITYPE_t d diff --git a/sklearn/metrics/_dist_metrics.pyx b/sklearn/metrics/_dist_metrics.pyx index 946d8d7735601..9b257058362c8 100644 --- a/sklearn/metrics/_dist_metrics.pyx +++ b/sklearn/metrics/_dist_metrics.pyx @@ -4,6 +4,8 @@ import numpy as np cimport numpy as np +from cython cimport final + np.import_array() # required in order to use C-API @@ -23,10 +25,10 @@ cdef inline np.ndarray _buffer_to_ndarray(const DTYPE_t* x, np.npy_intp n): return PyArray_SimpleNewFromData(1, &n, DTYPECODE, x) -# some handy constants from libc.math cimport fabs, sqrt, exp, pow, cos, sin, asin cdef DTYPE_t INF = np.inf +from scipy.sparse import csr_matrix, issparse from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE from ..utils._typedefs import DTYPE, ITYPE from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper @@ -68,6 +70,17 @@ METRIC_MAPPING = {'euclidean': EuclideanDistance, 'haversine': HaversineDistance, 'pyfunc': PyFuncDistance} +BOOL_METRICS = [ + "hamming", + "matching", + "jaccard", + "dice", + "kulsinski", + "rogerstanimoto", + "russellrao", + "sokalmichener", + "sokalsneath", +] def get_valid_metric_ids(L): """Given an iterable of metric class names or class identifiers, @@ -199,8 +212,8 @@ cdef class DistanceMetric: """ def __cinit__(self): self.p = 2 - self.vec = np.zeros(1, dtype=DTYPE, order='c') - self.mat = np.zeros((1, 1), dtype=DTYPE, order='c') + self.vec = np.zeros(1, dtype=DTYPE, order='C') + self.mat = np.zeros((1, 1), dtype=DTYPE, order='C') self.size = 1 def __reduce__(self): @@ -299,8 +312,9 @@ cdef class DistanceMetric: This can optionally be overridden in a base class. The rank-preserving surrogate distance is any measure that yields the same - rank as the distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. + rank as the distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. """ return self.dist(x1, x2, size) @@ -336,8 +350,9 @@ cdef class DistanceMetric: """Convert the rank-preserving surrogate distance to the distance. The surrogate distance is any measure that yields the same rank as the - distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. 
+ distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. Parameters ---------- @@ -355,8 +370,9 @@ cdef class DistanceMetric: """Convert the true distance to the rank-preserving surrogate distance. The surrogate distance is any measure that yields the same rank as the - distance, but is more efficient to compute. For example, for the - Euclidean metric, the surrogate distance is the squared-euclidean distance. + distance, but is more efficient to compute. For example, the + rank-preserving surrogate distance of the Euclidean metric is the + squared-euclidean distance. Parameters ---------- @@ -1191,3 +1207,158 @@ cdef class PyFuncDistance(DistanceMetric): cdef inline double fmax(double a, double b) nogil: return max(a, b) + + +###################################################################### +# Datasets Pair Classes +cdef class DatasetsPair: + """Abstract class which wraps a pair of datasets (X, Y). + + This class allows computing distances between a single pair of rows of + of X and Y at a time given the pair of their indices (i, j). This class is + specialized for each metric thanks to the :func:`get_for` factory classmethod. + + The handling of parallelization over chunks to compute the distances + and aggregation for several rows at a time is done in dedicated + subclasses of PairwiseDistancesReduction that in-turn rely on + subclasses of DatasetsPair for each pair of rows in the data. The goal + is to make it possible to decouple the generic parallelization and + aggregation logic from metric-specific computation as much as + possible. + + X and Y can be stored as np.ndarrays or CSR matrices in subclasses. + + This class avoids the overhead of dispatching distance computations + to :class:`sklearn.metrics.DistanceMetric` based on the physical + representation of the vectors (sparse vs. dense). It makes use of + cython.final to remove the overhead of dispatching method calls. + + Parameters + ---------- + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two vectors of (X, Y). + """ + + @classmethod + def get_for( + cls, + X, + Y, + str metric="euclidean", + dict metric_kwargs=None, + ) -> DatasetsPair: + """Return the DatasetsPair implementation for the given arguments. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + If provided as a ndarray, it must be C-contiguous. + If provided as a sparse matrix, it must be in CSR format. + + metric : str, default='euclidean' + The distance metric to use for argkmin. The default metric is + a fast implementation of the standard Euclidean metric. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + Returns + ------- + datasets_pair: DatasetsPair + The suited DatasetsPair implementation. 
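As a rough usage sketch (illustrative only; `get_for` is the factory added in this patch and currently accepts only dense, C-contiguous float64 arrays), the returned object is a `DenseDenseDatasetsPair` wrapping the requested `DistanceMetric`:

import numpy as np
from sklearn.metrics._dist_metrics import DatasetsPair, DenseDenseDatasetsPair

rng = np.random.RandomState(0)
X = rng.rand(100, 5)   # float64 and C-contiguous
Y = rng.rand(50, 5)

# metric_kwargs are forwarded to DistanceMetric.get_metric.
pair = DatasetsPair.get_for(X, Y, metric="minkowski", metric_kwargs={"p": 3})
assert isinstance(pair, DenseDenseDatasetsPair)
# Sparse or non-float64 inputs raise a ValueError instead.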
+ """ + cdef: + DistanceMetric distance_metric = DistanceMetric.get_metric( + metric, + **(metric_kwargs or {}) + ) + + if X.dtype != np.float64 or Y.dtype != np.float64: + raise ValueError("Only 64bit float datasets are supported for X and Y.") + + # Metric-specific checks that do not replace nor duplicate `check_array`. + distance_metric._validate_data(X) + distance_metric._validate_data(Y) + + if issparse(X) or issparse(Y): + raise ValueError("Only dense datasets are supported for X and Y.") + + return DenseDenseDatasetsPair(X, Y, distance_metric) + + @classmethod + def unpack_csr_matrix(cls, X: csr_matrix): + """Ensure getting ITYPE instead of int internally used for CSR matrices.""" + X_data = np.asarray(X.data, dtype=DTYPE) + X_indices = np.asarray(X.indices, dtype=ITYPE) + X_indptr = np.asarray(X.indptr, dtype=ITYPE) + return X_data, X_indptr, X_indptr + + def __init__(self, DistanceMetric distance_metric): + self.distance_metric = distance_metric + + cdef ITYPE_t n_samples_X(self) nogil: + """Number of samples in X.""" + return -999 + + cdef ITYPE_t n_samples_Y(self) nogil: + """Number of samples in Y.""" + return -999 + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.dist(i, j) + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + return -1 + +@final +cdef class DenseDenseDatasetsPair(DatasetsPair): + """Compute distances between vectors of two arrays. + + Parameters + ---------- + X: ndarray of shape (n_samples_X, n_features) + Rows represent vectors. Must be C-contiguous. + + Y: ndarray of shape (n_samples_Y, n_features) + Rows represent vectors. Must be C-contiguous. + + distance_metric: DistanceMetric + The distance metric responsible for computing distances + between two vectors of (X, Y). + """ + + def __init__(self, X, Y, DistanceMetric distance_metric): + super().__init__(distance_metric) + # Arrays have already been checked + self.X = X + self.Y = Y + self.d = X.shape[1] + + @final + cdef ITYPE_t n_samples_X(self) nogil: + return self.X.shape[0] + + @final + cdef ITYPE_t n_samples_Y(self) nogil: + return self.Y.shape[0] + + @final + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.rdist(&self.X[i, 0], + &self.Y[j, 0], + self.d) + + @final + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: + return self.distance_metric.dist(&self.X[i, 0], + &self.Y[j, 0], + self.d) diff --git a/sklearn/metrics/_pairwise_distances_reduction.pyx b/sklearn/metrics/_pairwise_distances_reduction.pyx new file mode 100644 index 0000000000000..e08a63e8fe4ab --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction.pyx @@ -0,0 +1,1031 @@ +# Pairwise Distances Reductions +# ============================= +# +# Author: Julien Jerphanion +# +# +# The routines defined here are used in various algorithms performing +# the same structure of operations on distances between vectors +# of a datasets pair (X, Y). +# +# Importantly, the core of the computation is chunked to make sure that the pairwise +# distance chunk matrices stay in CPU cache before applying the final reduction step. +# Furthermore, the chunking strategy is also used to leverage OpenMP-based parallelism +# (using Cython prange loops) which gives another multiplicative speed-up in +# favorable cases on many-core machines. +cimport numpy as np +import numpy as np +import warnings +import scipy.sparse + +from .. 
import get_config +from libc.stdlib cimport free, malloc +from libc.float cimport DBL_MAX +from cython cimport final +from cython.parallel cimport parallel, prange + +from ._dist_metrics cimport DatasetsPair, DenseDenseDatasetsPair +from ..utils._cython_blas cimport ( + BLAS_Order, + BLAS_Trans, + ColMajor, + NoTrans, + RowMajor, + Trans, + _dot, + _gemm, +) +from ..utils._heap cimport simultaneous_sort, heap_push +from ..utils._openmp_helpers cimport _openmp_thread_num +from ..utils._typedefs cimport ITYPE_t, DTYPE_t, DITYPE_t + +from numbers import Integral, Real +from typing import List +from scipy.sparse import issparse +from ._dist_metrics import BOOL_METRICS, METRIC_MAPPING +from ..utils import check_scalar, _in_unstable_openblas_configuration +from ..utils.fixes import threadpool_limits +from ..utils._openmp_helpers import _openmp_effective_n_threads +from ..utils._typedefs import ITYPE, DTYPE + +np.import_array() + +cpdef DTYPE_t[::1] _sqeuclidean_row_norms( + const DTYPE_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. + + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + DTYPE_t * X_ptr = &X[0, 0] + ITYPE_t idx = 0 + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] row_norms = np.empty(n, dtype=DTYPE) + + for idx in prange(n, schedule='static', nogil=True, num_threads=num_threads): + row_norms[idx] = _dot(d, X_ptr + idx * d, 1, X_ptr + idx * d, 1) + + return row_norms + +cdef class PairwiseDistancesReduction: + """Abstract base class for pairwise distance computation & reduction + + Subclasses of this class compute pairwise distances between a set of + vectors (rows) X and another set of vectors (rows) Y and apply a + reduction on top. The reduction takes a matrix of pairwise distances + between rows of X and Y as input and outputs an aggregate data-structure + for each row of X. The aggregate values are typically smaller than the number + of rows in Y, hence the term reduction. + + For computational reasons, it is interesting to perform the reduction on + the fly on chunks of rows of X and Y so as to keep intermediate + data-structures in CPU cache and avoid unnecessary round trips of large + distance arrays with the RAM that would otherwise severely degrade the + speed by making the overall processing memory-bound. + + The base class provides the generic chunked parallelization template using + OpenMP loops (Cython prange), either on rows of X or rows of Y depending on + their respective sizes. + + The subclasses are specialized for reduction. + + The actual distance computation for a given pair of rows of X and Y are + delegated to format-specific subclasses of the DatasetsPair companion base + class. + + Parameters + ---------- + datasets_pair: DatasetsPair + The pair of dataset to use. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + n_threads: int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on :method:`~PairwiseDistancesReduction.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. 
+ + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + Strategies differs on the dispatching they use for chunks on threads: + + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread then iterates on all the chunks of X. This strategy is + embarrassingly parallel but uses intermediate datastructures + synchronisation. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y'. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + """ + + cdef: + readonly DatasetsPair datasets_pair + + # The number of threads that can be used is stored in effective_n_threads. + # + # The number of threads to use in the parallelisation strategy + # (i.e. parallel_on_X or parallel_on_Y) can be smaller than effective_n_threads: + # for small datasets, less threads might be needed to loop over pair of chunks. + # + # Hence the number of threads that _will_ be used for looping over chunks + # is stored in chunks_n_threads, allowing solely using what we need. + # + # Thus, an invariant is: + # + # chunks_n_threads <= effective_n_threads + # + ITYPE_t effective_n_threads + ITYPE_t chunks_n_threads + + ITYPE_t n_samples_chunk, chunk_size + + ITYPE_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_remainder + ITYPE_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_remainder + + bint execute_in_parallel_on_Y + + @classmethod + def valid_metrics(cls) -> List[str]: + excluded = { + "pyfunc", # is relatively slow because we need to coerce data as np arrays + "mahalanobis", # is numerically unstable + # TODO: In order to support discrete distance metrics, we need to have a + # simultaneous sort which breaks ties on indices when distances are identical. + # The best might be using std::stable_sort and a Comparator taking an + # Arrays of Structures instead of Structure of Arrays (currently used). + "hamming", + *BOOL_METRICS, + } + return sorted(set(METRIC_MAPPING.keys()).difference(excluded)) + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + """Return True if the PairwiseDistancesReduction can be used for the given parameters. + + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples_X, n_features) + Input data. + + Y : {ndarray, sparse matrix} of shape (n_samples_Y, n_features) + Input data. + + metric : str, default='euclidean' + The distance metric to use. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + Returns + ------- + True if the PairwiseDistancesReduction can be used, else False. + """ + # Coercing to np.array to get the dtype + # TODO: what is the best way to get lists' dtype? 
+ X = np.asarray(X) if not isinstance(X, (np.ndarray, scipy.sparse.spmatrix)) else X + Y = np.asarray(Y) if not isinstance(Y, (np.ndarray, scipy.sparse.spmatrix)) else Y + # TODO: support sparse arrays and 32 bits + return (not issparse(X) and X.dtype == np.float64 and X.ndim == 2 and + not issparse(Y) and Y.dtype == np.float64 and Y.ndim == 2 and + metric in cls.valid_metrics()) + + def __init__( + self, + DatasetsPair datasets_pair, + chunk_size=None, + n_threads=None, + strategy=None, + ): + cdef: + ITYPE_t n_samples_chunk, X_n_full_chunks, Y_n_full_chunks + + if chunk_size is None: + chunk_size = get_config().get("pairwise_dist_chunk_size", 256) + + self.chunk_size = check_scalar(chunk_size, "chunk_size", Integral, min_val=20) + + self.effective_n_threads = _openmp_effective_n_threads(n_threads) + + self.datasets_pair = datasets_pair + + self.n_samples_X = datasets_pair.n_samples_X() + self.X_n_samples_chunk = min(self.n_samples_X, self.chunk_size) + X_n_full_chunks = self.n_samples_X // self.X_n_samples_chunk + self.X_n_samples_remainder = self.n_samples_X % self.X_n_samples_chunk + + self.n_samples_Y = datasets_pair.n_samples_Y() + self.Y_n_samples_chunk = min(self.n_samples_Y, self.chunk_size) + Y_n_full_chunks = self.n_samples_Y // self.Y_n_samples_chunk + self.Y_n_samples_remainder = self.n_samples_Y % self.Y_n_samples_chunk + + # Counting remainder chunk in total number of chunks + self.X_n_chunks = X_n_full_chunks + (self.X_n_samples_remainder != 0) + self.Y_n_chunks = Y_n_full_chunks + (self.Y_n_samples_remainder != 0) + + if strategy is None: + strategy = get_config().get("pairwise_dist_parallel_strategy", 'auto') + + if strategy not in ('parallel_on_X', 'parallel_on_Y', 'auto'): + raise RuntimeError(f"strategy must be 'parallel_on_X, 'parallel_on_Y', " + f"or 'auto', but currently strategy='{self.strategy}'.") + + if strategy == 'auto': + # This is a simple heuristic whose constant for the + # comparison has been chosen based on experiments. + if 4 * self.chunk_size * self.effective_n_threads < self.n_samples_X: + strategy = 'parallel_on_X' + else: + strategy = 'parallel_on_Y' + + self.execute_in_parallel_on_Y = strategy == "parallel_on_Y" + + # Not using less, not using more. + self.chunks_n_threads = min( + self.Y_n_chunks if self.execute_in_parallel_on_Y else self.X_n_chunks, + self.effective_n_threads, + ) + + @final + cdef void _parallel_on_X(self) nogil: + """Compute the pairwise distances of each vector (row) of X on Y + by parallelizing computation on chunks of X and reduce them. + + This strategy dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. 
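Both parallelization strategies below slice X and Y using the chunk bookkeeping computed in `__init__`: full chunks of `X_n_samples_chunk` (resp. `Y_n_samples_chunk`) rows, plus a smaller trailing chunk when the number of rows is not a multiple of the chunk size. A pure-Python sketch of that arithmetic, for illustration only:

def chunk_bounds(n_samples, n_samples_chunk):
    # Mirrors the start/end computation of the prange loops below.
    n_full_chunks, remainder = divmod(n_samples, n_samples_chunk)
    n_chunks = n_full_chunks + (remainder != 0)
    for chunk_idx in range(n_chunks):
        start = chunk_idx * n_samples_chunk
        if chunk_idx == n_chunks - 1 and remainder > 0:
            end = start + remainder
        else:
            end = start + n_samples_chunk
        yield start, end

assert list(chunk_bounds(10, 4)) == [(0, 4), (4, 8), (8, 10)]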
+ """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Allocating thread datastructures + self._parallel_on_X_parallel_init(thread_num) + + for X_chunk_idx in prange(self.X_n_chunks, schedule='static'): + X_start = X_chunk_idx * self.X_n_samples_chunk + if (X_chunk_idx == self.X_n_chunks - 1 + and self.X_n_samples_remainder > 0): + X_end = X_start + self.X_n_samples_remainder + else: + X_end = X_start + self.X_n_samples_chunk + + # Reinitializing thread datastructures for the new X chunk + self._parallel_on_X_init_chunk(thread_num, X_start) + + for Y_chunk_idx in range(self.Y_n_chunks): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if (Y_chunk_idx == self.Y_n_chunks - 1 + and self.Y_n_samples_remainder > 0): + Y_end = Y_start + self.Y_n_samples_remainder + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + + # Adjusting thread datastructures on the full pass on Y + self._parallel_on_X_prange_iter_finalize(thread_num, X_start, X_end) + + # end: for X_chunk_idx + + # Deallocating thread datastructures + self._parallel_on_X_parallel_finalize(thread_num) + + # end: with nogil, parallel + return + + @final + cdef void _parallel_on_Y(self) nogil: + """Compute the pairwise distances of each vector (row) of X on Y + by parallelizing computation on chunks of Y and reduce them. + + This strategy dispatches chunks of Y uniformly on threads. + Each thread then iterates on all the chunks of X. This strategy is + embarrassingly parallel but uses intermediate datastructures + synchronisation. + + Private datastructures are modified internally by threads. + + Private template methods can be implemented on subclasses to + interact with those datastructures at various stages. + """ + cdef: + ITYPE_t Y_start, Y_end, X_start, X_end, X_chunk_idx, Y_chunk_idx + ITYPE_t thread_num + + # Allocating datastructures + self._parallel_on_Y_parallel_init() + + for X_chunk_idx in range(self.X_n_chunks): + X_start = X_chunk_idx * self.X_n_samples_chunk + if X_chunk_idx == self.X_n_chunks - 1 and self.X_n_samples_remainder > 0: + X_end = X_start + self.X_n_samples_remainder + else: + X_end = X_start + self.X_n_samples_chunk + + with nogil, parallel(num_threads=self.chunks_n_threads): + thread_num = _openmp_thread_num() + + # Initializing datastructures used in this thread + self._parallel_on_Y_init(thread_num) + + for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'): + Y_start = Y_chunk_idx * self.Y_n_samples_chunk + if Y_chunk_idx == self.Y_n_chunks - 1 \ + and self.Y_n_samples_remainder > 0: + Y_end = Y_start + self.Y_n_samples_remainder + else: + Y_end = Y_start + self.Y_n_samples_chunk + + self._compute_and_reduce_distances_on_chunks( + X_start, X_end, + Y_start, Y_end, + thread_num, + ) + # end: prange + + # Note: we don't need a _parallel_on_Y_finalize similarly. + # This can be introduced if needed. 
+ + # end: with nogil, parallel + + # Synchronizing the thread datastructures with the main ones + self._parallel_on_Y_synchronize(X_start, X_end) + + # end: for X_chunk_idx + # Deallocating temporary datastructures and adjusting main datastructures + self._parallel_on_Y_finalize() + return + + # Placeholder methods which have to be implemented + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + """Compute the pairwise distances on two chunks of X and Y and reduce them. + + This is the core critical region of PairwiseDistanceReductions' computations + which must be implemented in subclasses. + """ + return + + def _finalize_results(self, bint return_distance): + """Call-back adapting datastructures before returning results. + + This must be implemented in subclasses. + """ + return None + + # Placeholder methods which can be implemented + + cdef void compute_exact_distances(self) nogil: + """Convert rank-preserving distances to exact distances or recompute them.""" + return + + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + """Allocate datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Interact with datastructures after a reduction on chunks.""" + return + + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + """Interact with datastructures after executing all the reductions.""" + return + + cdef void _parallel_on_Y_parallel_init( + self, + ) nogil: + """Allocate datastructures used in all threads.""" + return + + cdef void _parallel_on_Y_init( + self, + ITYPE_t thread_num, + ) nogil: + """Initialise datastructures used in a thread given its number.""" + return + + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + """Update thread datastructures before leaving a parallel region.""" + return + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + """Update datastructures after executing all the reductions.""" + return + +cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction): + """Compute the argkmin of vectors (rows) of X on the ones of Y. + + Parameters + ---------- + datasets_pair: DatasetsPair + The dataset pairs (X, Y) for the reduction. + + k: int + The k for the argkmin reduction. + + chunk_size: int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + n_threads: int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on :method:`~ArgKmin.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + """ + + cdef: + ITYPE_t k + + ITYPE_t[:, ::1] argkmin_indices + DTYPE_t[:, ::1] argkmin_distances + + # Used as array of pointers to private datastructures used in threads. 
+ DTYPE_t ** heaps_r_distances_chunks + ITYPE_t ** heaps_indices_chunks + + @classmethod + def compute( + cls, + X, + Y, + ITYPE_t k, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + n_threads=None, + str strategy=None, + bint return_distance=False, + ): + """Return the results of the reduction for the given arguments. + + Parameters + ---------- + X : ndarray or CSR matrix of shape (n_samples_X, n_features) + Input data. + + Y : ndarray or CSR matrix of shape (n_samples_Y, n_features) + Input data. + + k : int + The k for the argkmin reduction. + + metric : str, default='euclidean' + The distance metric to use for argkmin. + For a list of available metrics, see the documentation of + :class:`~sklearn.metrics.DistanceMetric`. + + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + + n_threads : int, default=None + The number of OpenMP threads to use for the reduction. + Parallelism is done on chunks and the sharding of chunks + depends on the `strategy` set on + :method:`~PairwiseDistancesArgKmin.compute`. + + See _openmp_effective_n_threads, for details about + the specification of n_threads. + + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + + Strategies differs on the dispatching they use for chunks on threads: + + - 'parallel_on__X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. This strategy is + embarrassingly parallel and comes with no datastructures synchronisation. + + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread then iterates on all the chunks of X. This strategy is + embarrassingly parallel but uses intermediate datastructures + synchronisation. + + - 'auto' relies on a simple heuristic to choose between + 'parallel_on__X' and 'parallel_on_Y'. + + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + + return_distance : boolean, default=False + Return distances between each X vector and its + argkmin if set to True. + + Returns + ------- + Indices of argkmin for each vector in X and its associated distances + if return_distance=True. + + Notes + ----- + This public classmethod is responsible for introspecting the arguments + values to dispatch to the private :meth:`PairwiseDistancesArgKmin._compute` + instance method of the most appropriate :class:`PairwiseDistancesArgKmin` + concrete implementation. + + All temporarily allocated datastructures necessary for the concrete + implementation are therefore freed when this classmethod returns. + + This allows entirely decoupling the interface entirely from the + implementation details whilst maintaining RAII. + """ + # Note (jjerphan): Some design thoughts for future extensions. + # This factory comes to handle specialisations for the given arguments. + # For future work, this might can be an entrypoint to specialise operations + # for various back-end and/or hardware and/or datatypes, and/or fused + # {sparse, dense}-datasetspair etc. 
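To make the dispatch below concrete, here is a rough sketch of a call to this public entry point (shapes and values are illustrative; the GEMM-based specialisation is selected for dense `euclidean`/`sqeuclidean` inputs, the generic `DatasetsPair`-backed implementation otherwise):

import numpy as np
from sklearn.metrics._pairwise_distances_reduction import PairwiseDistancesArgKmin

rng = np.random.RandomState(0)
X = rng.rand(1000, 50)   # queries: float64, C-contiguous
Y = rng.rand(2000, 50)   # index vectors

# Indices (and exact distances) of the 5 nearest rows of Y for each row of X.
distances, indices = PairwiseDistancesArgKmin.compute(
    X=X, Y=Y, k=5, metric="euclidean", return_distance=True,
)
assert distances.shape == indices.shape == (1000, 5)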
+ if ( + metric in ("euclidean", "sqeuclidean") + and not issparse(X) + and not issparse(Y) + ): + use_squared_distances = metric == "sqeuclidean" + pda = FastEuclideanPairwiseDistancesArgKmin( + X=X, Y=Y, k=k, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + else: # Fall back on the default + pda = PairwiseDistancesArgKmin( + datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results(return_distance) + + def __init__( + self, + DatasetsPair datasets_pair, + ITYPE_t k, + chunk_size=None, + n_threads=None, + strategy=None, + ): + super().__init__(datasets_pair, chunk_size, n_threads, strategy) + + self.k = check_scalar(k, "k", Integral, min_val=1) + + # Allocating pointers to datastructures but not the datastructures themselves. + # There are as many pointers as effective threads. + # + # For the sake of explicitness: + # - when parallelizing on X, those heaps pointers are referencing + # (with proper offsets) addresses of the two main heaps (see bellow) + # - when parallelizing on Y, those heaps pointer heaps are referencing + # small heaps which are thread-wise-allocated and whose content will be + # merged with the main heaps'. + self.heaps_r_distances_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + self.heaps_indices_chunks = malloc( + sizeof(ITYPE_t *) * self.chunks_n_threads + ) + + # Main heaps used by PairwiseDistancesArgKmin._compute to return results. + self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) + self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) + + def __dealloc__(self): + if self.heaps_indices_chunks is not NULL: + free(self.heaps_indices_chunks) + + if self.heaps_r_distances_chunks is not NULL: + free(self.heaps_r_distances_chunks) + + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + ITYPE_t n_samples_X = X_end - X_start + ITYPE_t n_samples_Y = Y_end - Y_start + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Pushing the distance and their associated indices on heaps + # which keep tracks of the argkmin. + for i in range(n_samples_X): + for j in range(n_samples_Y): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), + Y_start + j, + ) + + @final + cdef void _parallel_on_X_init_chunk( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ) nogil: + # As this strategy is embarrassingly parallel, we can set each + # thread's heaps pointer to the proper position on the main heaps. 
+ self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0] + self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0] + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + ITYPE_t thread_num, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx + + # Sorting indices of the argkmin for each query vector of X + for idx in range(X_end - X_start): + simultaneous_sort( + self.heaps_r_distances_chunks[thread_num] + idx * self.k, + self.heaps_indices_chunks[thread_num] + idx * self.k, + self.k + ) + + cdef void _parallel_on_Y_parallel_init( + self, + ) nogil: + cdef: + # Maximum number of scalar elements (the last chunks can be smaller) + ITYPE_t heaps_size = self.X_n_samples_chunk * self.k + ITYPE_t thread_num + + # The allocation is done in parallel for data locality purposes: this way + # the heaps used in each threads are allocated in pages which are closer + # to processor core used by the thread. + for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True, + num_threads=self.chunks_n_threads): + # As chunks of X are shared across threads, so must their + # heaps. To solve this, each thread has its own heaps + # which are then synchronised back in the main ones. + self.heaps_r_distances_chunks[thread_num] = malloc( + heaps_size * sizeof(DTYPE_t) + ) + self.heaps_indices_chunks[thread_num] = malloc( + heaps_size * sizeof(ITYPE_t) + ) + + @final + cdef void _parallel_on_Y_init( + self, + ITYPE_t thread_num, + ) nogil: + # Initialising heaps (memset can't be used here) + for idx in range(self.X_n_samples_chunk * self.k): + self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX + self.heaps_indices_chunks[thread_num][idx] = -1 + + @final + cdef void _parallel_on_Y_synchronize( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ) nogil: + cdef: + ITYPE_t idx, jdx, thread_num + with nogil, parallel(num_threads=self.effective_n_threads): + # Synchronising the thread heaps with the main heaps. + # This is done in parallel sample-wise (no need for locks). + # This might break each thread's data locality a bit but + # but this is negligible and this parallel pattern has + # shown to be efficient in practice. 
+ for idx in prange(X_end - X_start, schedule="static"): + for thread_num in range(self.chunks_n_threads): + for jdx in range(self.k): + heap_push( + &self.argkmin_distances[X_start + idx, 0], + &self.argkmin_indices[X_start + idx, 0], + self.k, + self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], + self.heaps_indices_chunks[thread_num][idx * self.k + jdx], + ) + + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef: + ITYPE_t idx, thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + # Deallocating temporary datastructures + for thread_num in prange(self.chunks_n_threads, schedule='static'): + free(self.heaps_r_distances_chunks[thread_num]) + free(self.heaps_indices_chunks[thread_num]) + + # Sort the main heaps into arrays in parallel + # in ascending order w.r.t the distances + for idx in prange(self.n_samples_X, schedule='static'): + simultaneous_sort( + &self.argkmin_distances[idx, 0], + &self.argkmin_indices[idx, 0], + self.k, + ) + return + + cdef void compute_exact_distances(self) nogil: + cdef: + ITYPE_t i, j + ITYPE_t[:, ::1] Y_indices = self.argkmin_indices + DTYPE_t[:, ::1] distances = self.argkmin_distances + for i in prange(self.n_samples_X, schedule='static', nogil=True, + num_threads=self.effective_n_threads): + for j in range(self.k): + distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( + # Guard against eventual -0., causing nan production. + max(distances[i, j], 0.) + ) + + def _finalize_results(self, bint return_distance=False): + if return_distance: + # We need to recompute distances because we relied on + # surrogate distances for the reduction. + self.compute_exact_distances() + return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) + + return np.asarray(self.argkmin_indices) + + +cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin): + """Fast specialized alternative for PairwiseDistancesArgKmin on EuclideanDistance. + + The full pairwise squared distances matrix is computed as follows: + + ||X - Y||² = ||X||² - 2 X.Y^T + ||Y||² + + The middle term gets computed efficiently bellow using BLAS Level 3 GEMM. + + Notes + ----- + This implementation has a superior arithmetic intensity and hence + better running time when the alternative is IO bound, but it can suffer + from numerical instability caused by catastrophic cancellation potentially + introduced by the subtraction in the arithmetic expression above. + + PairwiseDistancesArgKmin with EuclideanDistance must be used when higher + numerical precision is needed. 
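The expansion above can be checked with plain NumPy; a small sketch, for illustration only (the `-2 X.Y^T` middle term is the one delegated to GEMM in the chunked computation below):

import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(5, 3)
Y = rng.rand(7, 3)

# ||X - Y||^2 expanded as ||X||^2 - 2 X @ Y.T + ||Y||^2.
expanded = (
    (X ** 2).sum(axis=1)[:, np.newaxis]
    - 2 * X @ Y.T
    + (Y ** 2).sum(axis=1)[np.newaxis, :]
)
direct = ((X[:, np.newaxis, :] - Y[np.newaxis, :, :]) ** 2).sum(axis=-1)
np.testing.assert_allclose(expanded, direct)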
+ """ + + cdef: + const DTYPE_t[:, ::1] X + const DTYPE_t[:, ::1] Y + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + # Buffers for GEMM + DTYPE_t ** dist_middle_terms_chunks + bint use_squared_distances + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + return (PairwiseDistancesArgKmin.is_usable_for(X, Y, metric) and + not _in_unstable_openblas_configuration()) + + def __init__( + self, + X, + Y, + ITYPE_t k, + bint use_squared_distances=False, + chunk_size=None, + strategy=None, + metric_kwargs=None, + ): + if metric_kwargs is not None and len(metric_kwargs) > 0: + warnings.warn( + f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't" + f"usable for this case ({self.__class__.__name__}) and will be ignored.", + UserWarning, + stacklevel=3, + ) + + super().__init__( + # The datasets pair here is used for exact distances computations + datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + k=k, + chunk_size=chunk_size, + ) + # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + cdef: + DenseDenseDatasetsPair datasets_pair = self.datasets_pair + self.X, self.Y = datasets_pair.X, datasets_pair.Y + + if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: + self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared", None) + else: + self.Y_norm_squared = _sqeuclidean_row_norms(self.Y, self.effective_n_threads) + + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms(self.X, self.effective_n_threads) + ) + self.use_squared_distances = use_squared_distances + + # Temporary datastructures used in threads + self.dist_middle_terms_chunks = malloc( + sizeof(DTYPE_t *) * self.chunks_n_threads + ) + + def __dealloc__(self): + if self.dist_middle_terms_chunks is not NULL: + free(self.dist_middle_terms_chunks) + + @final + cdef void compute_exact_distances(self) nogil: + if not self.use_squared_distances: + PairwiseDistancesArgKmin.compute_exact_distances(self) + + @final + cdef void _parallel_on_X_parallel_init( + self, + ITYPE_t thread_num, + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_init(self, thread_num) + + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_X_parallel_finalize( + self, + ITYPE_t thread_num + ) nogil: + PairwiseDistancesArgKmin._parallel_on_X_parallel_finalize(self, thread_num) + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _parallel_on_Y_parallel_init( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_parallel_init(self) + + for thread_num in range(self.chunks_n_threads): + # Temporary buffer for the `-2 * X_c @ Y_c.T` term + self.dist_middle_terms_chunks[thread_num] = malloc( + self.Y_n_samples_chunk * self.X_n_samples_chunk * sizeof(DTYPE_t) + ) + + @final + cdef void _parallel_on_Y_finalize( + self, + ) nogil: + cdef ITYPE_t thread_num + PairwiseDistancesArgKmin._parallel_on_Y_finalize(self) + + for thread_num in range(self.chunks_n_threads): + free(self.dist_middle_terms_chunks[thread_num]) + + @final + cdef void _compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + ITYPE_t i, j + + const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] + const 
DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num] + DTYPE_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num] + ITYPE_t *heaps_indices = self.heaps_indices_chunks[thread_num] + + # Careful: LDA, LDB and LDC are given for F-ordered arrays + # in BLAS documentations, for instance: + # https://www.netlib.org/lapack/explore-html/db/dc9/group__single__blas__level3_gafe51bacb54592ff5de056acabd83c260.html #noqa + # + # Here, we use their counterpart values to work with C-ordered arrays. + BLAS_Order order = RowMajor + BLAS_Trans ta = NoTrans + BLAS_Trans tb = Trans + ITYPE_t m = X_c.shape[0] + ITYPE_t n = Y_c.shape[0] + ITYPE_t K = X_c.shape[1] + DTYPE_t alpha = - 2. + # Casting for A and B to remove the const is needed because APIs exposed via + # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. + DTYPE_t * A = & X_c[0, 0] + ITYPE_t lda = X_c.shape[1] + DTYPE_t * B = & Y_c[0, 0] + ITYPE_t ldb = X_c.shape[1] + DTYPE_t beta = 0. + DTYPE_t * C = dist_middle_terms + ITYPE_t ldc = Y_c.shape[0] + + # dist_middle_terms = `-2 * X_c @ Y_c.T` + _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, C, ldc) + + # Pushing the distance and their associated indices on heaps + # which keep tracks of the argkmin. + for i in range(X_c.shape[0]): + for j in range(Y_c.shape[0]): + heap_push( + heaps_r_distances + i * self.k, + heaps_indices + i * self.k, + self.k, + # Using the squared euclidean distance as the rank-preserving distance: + # + # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² + # + ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * Y_c.shape[0] + j] + + self.Y_norm_squared[j + Y_start] + ), + j + Y_start, + ) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 3292f1ff05767..36fdd513c76e4 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -19,6 +19,7 @@ from scipy.sparse import issparse from joblib import Parallel, effective_n_jobs +from .. 
import config_context from ..utils.validation import _num_samples from ..utils.validation import check_non_negative from ..utils import check_array @@ -31,6 +32,7 @@ from ..utils.fixes import delayed from ..utils.fixes import sp_version, parse_version +from ._pairwise_distances_reduction import PairwiseDistancesArgKmin from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan from ..exceptions import DataConversionWarning @@ -582,6 +584,10 @@ def _argmin_min_reduce(dist, start): return indices, values +def _argmin_reduce(dist, start): + return dist.argmin(axis=1) + + def pairwise_distances_argmin_min( X, Y, *, axis=1, metric="euclidean", metric_kwargs=None ): @@ -654,19 +660,38 @@ def pairwise_distances_argmin_min( """ X, Y = check_pairwise_arrays(X, Y) - if metric_kwargs is None: - metric_kwargs = {} - if axis == 0: X, Y = Y, X - indices, values = zip( - *pairwise_distances_chunked( - X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs + if metric_kwargs is None: + metric_kwargs = {} + + if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric): + values, indices = PairwiseDistancesArgKmin.compute( + X=X, + Y=Y, + k=1, + metric=metric, + metric_kwargs=metric_kwargs, + strategy="auto", + return_distance=True, ) - ) - indices = np.concatenate(indices) - values = np.concatenate(values) + values = values.flatten() + indices = indices.flatten() + else: + # TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit, + # we won't need to fallback to pairwise_distances_chunked anymore. + + # Turn off check for finiteness because this is costly and because arrays + # have already been validated. + with config_context(assume_finite=True): + indices, values = zip( + *pairwise_distances_chunked( + X, Y, reduce_func=_argmin_min_reduce, metric=metric, **metric_kwargs + ) + ) + indices = np.concatenate(indices) + values = np.concatenate(values) return indices, values @@ -738,9 +763,43 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs if metric_kwargs is None: metric_kwargs = {} - return pairwise_distances_argmin_min( - X, Y, axis=axis, metric=metric, metric_kwargs=metric_kwargs - )[0] + X, Y = check_pairwise_arrays(X, Y) + + if axis == 0: + X, Y = Y, X + + if metric_kwargs is None: + metric_kwargs = {} + + if PairwiseDistancesArgKmin.is_usable_for(X, Y, metric): + indices = PairwiseDistancesArgKmin.compute( + X=X, + Y=Y, + k=1, + metric=metric, + metric_kwargs=metric_kwargs, + strategy="auto", + return_distance=False, + ) + indices = indices.flatten() + else: + # TODO: once PairwiseDistancesArgKmin supports sparse input matrices and 32 bit, + # we won't need to fallback to pairwise_distances_chunked anymore. + + # Turn off check for finiteness because this is costly and because arrays + # have already been validated. + with config_context(assume_finite=True): + indices = np.concatenate( + list( + # This returns a np.ndarray generator whose arrays we need + # to flatten into one. 
+ pairwise_distances_chunked( + X, Y, reduce_func=_argmin_reduce, metric=metric, **metric_kwargs + ) + ) + ) + + return indices def haversine_distances(X, Y=None): diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index 69925a3590be6..29d7c870202a1 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -19,6 +19,12 @@ def configuration(parent_package="", top_path=None): "_pairwise_fast", sources=["_pairwise_fast.pyx"], libraries=libraries ) + config.add_extension( + "_pairwise_distances_reduction", + sources=["_pairwise_distances_reduction.pyx"], + libraries=libraries, + ) + config.add_extension( "_dist_metrics", sources=["_dist_metrics.pyx"], diff --git a/sklearn/metrics/tests/test_dist_metrics.py b/sklearn/metrics/tests/test_dist_metrics.py index bf258ea564c8c..21dc1a9c76330 100644 --- a/sklearn/metrics/tests/test_dist_metrics.py +++ b/sklearn/metrics/tests/test_dist_metrics.py @@ -10,6 +10,7 @@ import scipy.sparse as sp from scipy.spatial.distance import cdist from sklearn.metrics import DistanceMetric +from sklearn.metrics._dist_metrics import BOOL_METRICS from sklearn.utils import check_random_state from sklearn.utils._testing import create_memmap_backed_data from sklearn.utils.fixes import sp_version, parse_version @@ -38,17 +39,6 @@ def dist_func(x1, x2, p): V = rng.random_sample((d, d)) VI = np.dot(V, V.T) -BOOL_METRICS = [ - "hamming", - "matching", - "jaccard", - "dice", - "kulsinski", - "rogerstanimoto", - "russellrao", - "sokalmichener", - "sokalsneath", -] METRICS_DEFAULT_PARAMS = [ ("euclidean", {}), @@ -75,6 +65,17 @@ def dist_func(x1, x2, p): ) +# TODO: remove this test in 1.3 +def test_neighbors_distance_metric_deprecation(): + from sklearn.neighbors import DistanceMetric as DeprecatedDistanceMetric + + with pytest.warns( + FutureWarning, match="sklearn.neighbors.DistanceMetric has been moved" + ): + DeprecatedDistanceMetric.get_metric("euclidean") + + +@pytest.mark.parametrize("metric", METRICS_DEFAULT_PARAMS) def check_cdist(metric, kwargs, X1, X2): if metric == "wminkowski": # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0 diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index 93273ed915a28..4075e19c97d08 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -458,12 +458,12 @@ def test_pairwise_distances_argmin_min(): assert type(idxsp) == np.ndarray assert type(valssp) == np.ndarray - # euclidean metric squared - idx, vals = pairwise_distances_argmin_min( - X, Y, metric="euclidean", metric_kwargs={"squared": True} - ) + # Squared Euclidean metric + idx, vals = pairwise_distances_argmin_min(X, Y, metric="sqeuclidean") + idx2 = pairwise_distances_argmin(X, Y, metric="sqeuclidean") assert_array_almost_equal(idx, expected_idx) assert_array_almost_equal(vals, expected_vals_sq) + assert_array_almost_equal(idx2, expected_idx) # Non-euclidean scikit-learn metric idx, vals = pairwise_distances_argmin_min(X, Y, metric="manhattan") diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py new file mode 100644 index 0000000000000..b40ec64b895ea --- /dev/null +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -0,0 +1,421 @@ +import numpy as np +import pytest +from collections import defaultdict +from numpy.testing import assert_array_equal, assert_allclose +from scipy.sparse import csr_matrix + +from sklearn.metrics._pairwise_distances_reduction import ( + 
PairwiseDistancesReduction, + PairwiseDistancesArgKmin, + FastEuclideanPairwiseDistancesArgKmin, + _sqeuclidean_row_norms, +) + +from sklearn.utils import _in_unstable_openblas_configuration +from sklearn.utils.fixes import sp_version, parse_version +from sklearn.utils._testing import fails_if_unstable_openblas + + +def _get_dummy_metric_params_list(metric: str, n_features: int): + """Return list of dummy DistanceMetric kwargs for tests.""" + + rng = np.random.RandomState(1) + weights = rng.random_sample(n_features) + weights /= weights.sum() + + V = rng.random_sample((n_features, n_features)) + + # VI is positive-semidefinite, preferred for precision matrix + VI = np.dot(V, V.T) + 3 * np.eye(n_features) + + METRICS_PARAMS = defaultdict( + list, + { + "euclidean": [{}], + "manhattan": [{}], + "minkowski": [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)], + "chebyshev": [{}], + "seuclidean": [dict(V=rng.rand(n_features))], + "haversine": [{}], + "wminkowski": [dict(p=1.5, w=weights)], + "mahalanobis": [dict(VI=VI)], + }, + ) + + wminkowski_kwargs = dict(p=3, w=rng.rand(n_features)) + + if sp_version < parse_version("1.8.0.dev0"): + # TODO: remove once we no longer support scipy < 1.8.0. + # wminkowski was removed in scipy 1.8.0 but should work for previous + # versions. + METRICS_PARAMS["wminkowski"].append(wminkowski_kwargs) # type: ignore + else: + # Recent scipy versions accept weights in the Minkowski metric directly: + # type: ignore + METRICS_PARAMS["minkowski"].append(wminkowski_kwargs) # type: ignore + + return METRICS_PARAMS.get(metric, [{}]) + + +def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices): + assert_array_equal( + ref_indices, + indices, + err_msg="Query vectors have different neighbors' indices", + ) + assert_allclose( + ref_dist, + dist, + err_msg="Query vectors have different neighbors' distances", + rtol=1e-7, + ) + + +ASSERT_RESULT = { + PairwiseDistancesArgKmin: assert_argkmin_results_equality, +} + + +def test_pairwise_distances_reduction_is_usable_for(): + rng = np.random.RandomState(0) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + assert PairwiseDistancesReduction.is_usable_for(X, Y, metric) + assert not PairwiseDistancesReduction.is_usable_for( + X.astype(np.int64), Y.astype(np.int64), metric + ) + + assert not PairwiseDistancesReduction.is_usable_for(X[0], Y, metric) + assert not PairwiseDistancesReduction.is_usable_for(X, Y[0], metric) + + assert not PairwiseDistancesReduction.is_usable_for(X, Y, metric="pyfunc") + # TODO: remove once 32 bits datasets are supported + assert not PairwiseDistancesReduction.is_usable_for(X.astype(np.float32), Y, metric) + assert not PairwiseDistancesReduction.is_usable_for(X, Y.astype(np.int32), metric) + + # TODO: remove once sparse matrices are supported + assert not PairwiseDistancesReduction.is_usable_for(csr_matrix(X), Y, metric) + assert not PairwiseDistancesReduction.is_usable_for(X, csr_matrix(Y), metric) + + +def test_argkmin_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + k = 5 + metric = "euclidean" + + with pytest.raises( + ValueError, match="Only 64bit float datasets are supported for X and Y." + ): + PairwiseDistancesArgKmin.compute( + X=X.astype(np.float32), Y=Y, k=k, metric=metric + ) + + with pytest.raises( + ValueError, match="Only 64bit float datasets are supported for X and Y." 
+ ): + PairwiseDistancesArgKmin.compute(X=X, Y=Y.astype(np.int32), k=k, metric=metric) + + with pytest.raises(ValueError, match="k == -1, must be >= 1."): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=-1, metric=metric) + + with pytest.raises(ValueError, match="k == 0, must be >= 1."): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=0, metric=metric) + + with pytest.raises(ValueError, match="Unrecognized metric"): + PairwiseDistancesArgKmin.compute(X=X, Y=Y, k=k, metric="wrong metric") + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + PairwiseDistancesArgKmin.compute( + X=np.array([1.0, 2.0]), Y=Y, k=k, metric=metric + ) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + PairwiseDistancesArgKmin.compute( + X=np.asfortranarray(X), Y=Y, k=k, metric=metric + ) + + +@fails_if_unstable_openblas +@pytest.mark.filterwarnings("ignore:Constructing a DIA matrix") +@pytest.mark.parametrize( + "PairwiseDistancesReduction, FastPairwiseDistancesReduction", + [ + (PairwiseDistancesArgKmin, FastEuclideanPairwiseDistancesArgKmin), + ], +) +def test_pairwise_distances_reduction_factory_method( + PairwiseDistancesReduction, FastPairwiseDistancesReduction +): + # Test all the combinations of DatasetsPair for creation + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + metric = "euclidean" + + # Dummy value for k or radius + dummy_arg = 5 + + with pytest.raises( + ValueError, match="Only dense datasets are supported for X and Y." + ): + PairwiseDistancesReduction.compute( + csr_matrix(X), + csr_matrix(Y), + dummy_arg, + metric, + ) + + with pytest.raises( + ValueError, match="Only dense datasets are supported for X and Y." + ): + PairwiseDistancesReduction.compute(X, csr_matrix(Y), dummy_arg, metric=metric) + + with pytest.raises( + ValueError, match="Only dense datasets are supported for X and Y." 
+ ): + PairwiseDistancesReduction.compute(csr_matrix(X), Y, dummy_arg, metric=metric) + + +@fails_if_unstable_openblas +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("chunk_size", [50, 512, 1024]) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_chunk_size_agnosticism( + PairwiseDistancesReduction, + seed, + n_samples, + chunk_size, + n_features=100, + dtype=np.float64, +): + # Results should not depend on the chunk size + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the number of dimensions + else 10 ** np.log(n_features) + ) + + ref_dist, ref_indices = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + return_distance=True, + ) + + dist, indices = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + chunk_size=chunk_size, + return_distance=True, + ) + + ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + + +@fails_if_unstable_openblas +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("chunk_size", [50, 512, 1024]) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_n_threads_agnosticism( + PairwiseDistancesReduction, + seed, + n_samples, + chunk_size, + n_features=100, + dtype=np.float64, +): + # Results should not depend on the number of threads + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the number of dimensions + else 10 ** np.log(n_features) + ) + + ref_dist, ref_indices = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + return_distance=True, + ) + + dist, indices = PairwiseDistancesReduction.compute( + X, Y, parameter, n_threads=1, return_distance=True + ) + + ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices) + + +@pytest.mark.parametrize("seed", range(5)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics()) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_strategies_consistency( + PairwiseDistancesReduction, + metric, + n_samples, + seed, + n_features=10, + dtype=np.float64, +): + # Results obtained using both parallelization strategies must be identical + if _in_unstable_openblas_configuration() and metric in ("sqeuclidean", "euclidean"): + pytest.xfail( + "OpenBLAS (used for '(sq)euclidean') is unstable in this configuration" + ) + + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = np.ascontiguousarray(Y[:, :2]) + + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the number of dimensions + else 10 ** np.log(n_features) + ) + + 
dist_par_X, indices_par_X = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + # Taking the first set of dummy metric parameters + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + # To be sure to use parallelization + chunk_size=n_samples // 4, + strategy="parallel_on_X", + return_distance=True, + ) + + dist_par_Y, indices_par_Y = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + # Taking the first set of dummy metric parameters + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + # To be sure to use parallelization + chunk_size=n_samples // 4, + strategy="parallel_on_Y", + return_distance=True, + ) + + ASSERT_RESULT[PairwiseDistancesReduction]( + dist_par_X, + dist_par_Y, + indices_par_X, + indices_par_Y, + ) + + +@fails_if_unstable_openblas +@pytest.mark.parametrize("n_features", [50, 500]) +@pytest.mark.parametrize("translation", [10 ** i for i in [4, 8]]) +@pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics()) +@pytest.mark.parametrize( + "PairwiseDistancesReduction", + [PairwiseDistancesArgKmin], +) +def test_euclidean_translation_invariance( + n_features, + translation, + metric, + PairwiseDistancesReduction, + n_samples=1000, + dtype=np.float64, +): + # The reduction must be translation invariant. + parameter = ( + 10 + if PairwiseDistancesReduction is PairwiseDistancesArgKmin + # Scaling the radius slightly with the number of dimensions + else 10 ** np.log(n_features) + ) + + rng = np.random.RandomState(0) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + Y = rng.rand(n_samples, n_features).astype(dtype) * spread + + # Haversine distance only accepts 2D data + if metric == "haversine": + X = np.ascontiguousarray(X[:, :2]) + Y = np.ascontiguousarray(Y[:, :2]) + + reference_dist, reference_indices = PairwiseDistancesReduction.compute( + X, + Y, + parameter, + metric=metric, + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + return_distance=True, + ) + + dist, indices = PairwiseDistancesReduction.compute( + X + translation, + Y + translation, + parameter, + metric=metric, + metric_kwargs=_get_dummy_metric_params_list(metric, n_features)[0], + return_distance=True, + ) + + ASSERT_RESULT[PairwiseDistancesReduction]( + reference_dist, dist, reference_indices, indices + ) + + +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("n_samples", [100, 1000]) +@pytest.mark.parametrize("n_features", [5, 10, 100]) +@pytest.mark.parametrize("num_threads", [1, 2, 8]) +def test_sqeuclidean_row_norms( + seed, + n_samples, + n_features, + num_threads, + dtype=np.float64, +): + rng = np.random.RandomState(seed) + spread = 100 + X = rng.rand(n_samples, n_features).astype(dtype) * spread + + sq_row_norm_reference = np.linalg.norm(X, axis=1) ** 2 + sq_row_norm = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads)) + + assert_allclose(sq_row_norm_reference, sq_row_norm) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 8adb58b4f8c6c..cf5a4a447d0ac 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -22,6 +22,9 @@ from ..base import is_classifier from ..metrics import pairwise_distances_chunked from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS +from ..metrics._pairwise_distances_reduction import ( + PairwiseDistancesArgKmin, +) from ..utils import ( check_array, gen_even_slices, @@ -353,31 +356,33 @@ def _check_algorithm_metric(self): if self.algorithm not in ["auto", "brute", "kd_tree", "ball_tree"]: raise ValueError("unrecognized 
algorithm: '%s'" % self.algorithm) + self._metric = self.metric + if self.algorithm == "auto": - if self.metric == "precomputed": + if self._metric == "precomputed": alg_check = "brute" - elif callable(self.metric) or self.metric in VALID_METRICS["ball_tree"]: + elif callable(self._metric) or self._metric in VALID_METRICS["ball_tree"]: alg_check = "ball_tree" else: alg_check = "brute" else: alg_check = self.algorithm - if callable(self.metric): + if callable(self._metric): if self.algorithm == "kd_tree": # callable metric is only valid for brute force and ball_tree raise ValueError( "kd_tree does not support callable metric '%s'" "Function call overhead will result" "in very poor performance." - % self.metric + % self._metric ) - elif self.metric not in VALID_METRICS[alg_check]: + elif self._metric not in VALID_METRICS[alg_check]: raise ValueError( "Metric '%s' not valid. Use " "sorted(sklearn.neighbors.VALID_METRICS['%s']) " "to get valid options. " - "Metric can also be a callable function." % (self.metric, alg_check) + "Metric can also be a callable function." % (self._metric, alg_check) ) if self.metric_params is not None and "p" in self.metric_params: @@ -393,7 +398,7 @@ def _check_algorithm_metric(self): else: effective_p = self.p - if self.metric in ["wminkowski", "minkowski"] and effective_p < 1: + if self._metric in ["wminkowski", "minkowski"] and effective_p < 1: raise ValueError("p must be greater or equal to one for minkowski metric") def _fit(self, X, y=None): @@ -443,12 +448,12 @@ def _fit(self, X, y=None): self.effective_metric_params_ = self.metric_params.copy() effective_p = self.effective_metric_params_.get("p", self.p) - if self.metric in ["wminkowski", "minkowski"]: + if self._metric in ["wminkowski", "minkowski"]: self.effective_metric_params_["p"] = effective_p - self.effective_metric_ = self.metric + self.effective_metric_ = self._metric # For minkowski distance, use more efficient methods where available - if self.metric == "minkowski": + if self._metric == "minkowski": p = self.effective_metric_params_.pop("p", 2) w = self.effective_metric_params_.pop("w", None) if p < 1: @@ -487,7 +492,7 @@ def _fit(self, X, y=None): self.n_samples_fit_ = X.data.shape[0] return self - if self.metric == "precomputed": + if self._metric == "precomputed": X = _check_precomputed(X) # Precomputed matrix X must be squared if X.shape[0] != X.shape[1]: @@ -504,6 +509,7 @@ def _fit(self, X, y=None): if issparse(X): if self.algorithm not in ("auto", "brute"): warnings.warn("cannot use tree with sparse input: using brute force") + if self.effective_metric_ not in VALID_METRICS_SPARSE[ "brute" ] and not callable(self.effective_metric_): @@ -528,7 +534,7 @@ def _fit(self, X, y=None): # A tree approach is better for small number of neighbors or small # number of features, with KDTree generally faster when available if ( - self.metric == "precomputed" + self._metric == "precomputed" or self._fit_X.shape[1] > 15 or ( self.n_neighbors is not None @@ -651,10 +657,7 @@ def _kneighbors_reduce_func(self, dist, start, n_neighbors, return_distance): # argpartition doesn't guarantee sorted order, so we sort again neigh_ind = neigh_ind[sample_range, np.argsort(dist[sample_range, neigh_ind])] if return_distance: - if self.effective_metric_ == "euclidean": - result = np.sqrt(dist[sample_range, neigh_ind]), neigh_ind - else: - result = dist[sample_range, neigh_ind], neigh_ind + result = dist[sample_range, neigh_ind], neigh_ind else: result = neigh_ind return result @@ -724,18 +727,37 @@ class from an array 
representing our data set and ask who's % type(n_neighbors) ) - if X is not None: - query_is_train = False - if self.metric == "precomputed": - X = _check_precomputed(X) - else: - X = self._validate_data(X, accept_sparse="csr", reset=False) - else: - query_is_train = True + use_pairwise_distances_reductions = ( + self._fit_method == "brute" + and PairwiseDistancesArgKmin.is_usable_for( + X if X is not None else self._fit_X, self._fit_X, self.effective_metric_ + ) + ) + + query_is_train = X is None + if query_is_train: + if use_pairwise_distances_reductions: + # We force the C-contiguity even if it creates a copy for F-ordered + # arrays because PairwiseDistancesArgKmin is more efficient. + self._fit_X = self._validate_data( + self._fit_X, accept_sparse="csr", reset=False, order="C" + ) X = self._fit_X # Include an extra neighbor to account for the sample itself being # returned, which is removed later n_neighbors += 1 + else: + if use_pairwise_distances_reductions: + # We force the C-contiguity even if it creates a copy for F-ordered + # arrays because PairwiseDistancesArgKmin is more efficient. + X = self._validate_data(X, accept_sparse="csr", reset=False, order="C") + self._fit_X = self._validate_data( + self._fit_X, accept_sparse="csr", reset=False, order="C" + ) + elif self._metric == "precomputed": + X = _check_precomputed(X) + else: + X = self._validate_data(X, accept_sparse="csr", reset=False) n_samples_fit = self.n_samples_fit_ if n_neighbors > n_samples_fit: @@ -746,24 +768,36 @@ class from an array representing our data set and ask who's n_jobs = effective_n_jobs(self.n_jobs) chunked_results = None - if self._fit_method == "brute" and self.metric == "precomputed" and issparse(X): + if use_pairwise_distances_reductions: + results = PairwiseDistancesArgKmin.compute( + X=X, + Y=self._fit_X, + k=n_neighbors, + metric=self.effective_metric_, + metric_kwargs=self.effective_metric_params_, + n_threads=self.n_jobs, + strategy="auto", + return_distance=return_distance, + ) + + elif ( + self._fit_method == "brute" + and self._metric == "precomputed" + and issparse(X) + ): results = _kneighbors_from_graph( X, n_neighbors=n_neighbors, return_distance=return_distance ) elif self._fit_method == "brute": + # TODO: support sparse matrices + reduce_func = partial( self._kneighbors_reduce_func, n_neighbors=n_neighbors, return_distance=return_distance, ) - # for efficiency, use squared euclidean distances - if self.effective_metric_ == "euclidean": - kwds = {"squared": True} - else: - kwds = self.effective_metric_params_ - chunked_results = list( pairwise_distances_chunked( X, @@ -771,7 +805,7 @@ class from an array representing our data set and ask who's reduce_func=reduce_func, metric=self.effective_metric_, n_jobs=n_jobs, - **kwds, + **self.effective_metric_params_, ) ) @@ -1061,6 +1095,8 @@ class from an array representing our data set and ask who's ) elif self._fit_method == "brute": + # TODO: support sparse matrices + # for efficiency, use squared euclidean distances if self.effective_metric_ == "euclidean": radius *= radius diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index b3f02cd74d7b5..649a92548eed3 100644 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -758,7 +758,7 @@ def newObj(obj): ###################################################################### # define the reverse mapping of VALID_METRICS -from sklearn.metrics._dist_metrics import get_valid_metric_ids +from ..metrics._dist_metrics import 
get_valid_metric_ids VALID_METRIC_IDS = get_valid_metric_ids(VALID_METRICS) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index bcad8c71aee07..eaa2ba05eb55a 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -65,7 +65,7 @@ class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase): (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. metric : str or callable, default='minkowski' - The distance metric to use for the tree. The default metric is + The distance metric to use for the tree. The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean metric. For a list of available metrics, see the documentation of :class:`~sklearn.metrics.DistanceMetric` and the metrics listed in @@ -628,6 +628,7 @@ def predict_proba(self, X): n_queries = _num_samples(X) neigh_dist, neigh_ind = self.radius_neighbors(X) + outlier_mask = np.zeros(n_queries, dtype=bool) outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind] outliers = np.flatnonzero(outlier_mask) diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py index c94fbebe37704..fd934ad054b6e 100644 --- a/sklearn/neighbors/tests/test_neighbors.py +++ b/sklearn/neighbors/tests/test_neighbors.py @@ -378,8 +378,6 @@ def make_train_test(X_train, X_test): estimators = [ neighbors.KNeighborsClassifier, neighbors.KNeighborsRegressor, - neighbors.RadiusNeighborsClassifier, - neighbors.RadiusNeighborsRegressor, ] check_precomputed(make_train_test, estimators) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index e469f23104398..cccaf7c44d2a7 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -14,6 +14,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } # Not using as a context manager affects nothing @@ -26,6 +27,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } assert get_config()["assume_finite"] is False @@ -55,6 +57,7 @@ def test_config_context(): "working_memory": 1024, "print_changed_only": True, "display": "text", + "pairwise_dist_chunk_size": 256, } # No positional arguments diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8c69f55dee917..932f91a433b47 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -26,7 +26,7 @@ from . import _joblib from ..exceptions import DataConversionWarning from .deprecation import deprecated -from .fixes import np_version, parse_version +from .fixes import np_version, parse_version, threadpool_info from ._estimator_html_repr import estimator_html_repr from .validation import ( as_float_array, @@ -83,6 +83,39 @@ _IS_32BIT = 8 * struct.calcsize("P") == 32 +def _in_unstable_openblas_configuration(): + """Return True if in an unstable configuration for OpenBLAS""" + + # Import libraries which might load OpenBLAS. 
+ import numpy # noqa + import scipy # noqa + + modules_info = threadpool_info() + + open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info) + if not open_blas_used: + return False + + # OpenBLAS 0.3.16 fixed instability for arm64, see: + # https://github.com/xianyi/OpenBLAS/blob/1b6db3dbba672b4f8af935bd43a1ff6cff4d20b7/Changelog.txt#L56-L58 # noqa + openblas_arm64_stable_version = parse_version("0.3.16") + for info in modules_info: + if info["internal_api"] != "openblas": + continue + openblas_version = info.get("version") + openblas_architecture = info.get("architecture") + if openblas_version is None or openblas_architecture is None: + # Cannot be sure that OpenBLAS is good enough. Assume unstable: + return True + if ( + openblas_architecture == "neoversen1" + and parse_version(openblas_version) < openblas_arm64_stable_version + ): + # See discussions in https://github.com/numpy/numpy/issues/19411 + return True + return False + + def safe_mask(X, mask): """Return a mask which is safe to use on X. diff --git a/sklearn/utils/_fast_dict.pyx b/sklearn/utils/_fast_dict.pyx index 6d7e62eefc07f..9d9234682f02a 100644 --- a/sklearn/utils/_fast_dict.pyx +++ b/sklearn/utils/_fast_dict.pyx @@ -68,7 +68,7 @@ cdef class IntFloatDict: # while it != end: # yield deref(it).first, deref(it).second # inc(it) - + def __iter__(self): cdef int size = self.my_map.size() cdef ITYPE_t [:] keys = np.empty(size, dtype=np.intp) @@ -147,4 +147,3 @@ def argmin(IntFloatDict d): min_key = deref(it).first inc(it) return min_key, min_value - diff --git a/sklearn/utils/_openmp_helpers.pxd b/sklearn/utils/_openmp_helpers.pxd new file mode 100644 index 0000000000000..e57fc9bfa6bf5 --- /dev/null +++ b/sklearn/utils/_openmp_helpers.pxd @@ -0,0 +1,6 @@ +# Helpers to access OpenMP threads information +# +# These interfaces act as indirections which allow code written for OpenMP +# to also be used when OpenMP support is not available. + +cdef int _openmp_thread_num() nogil diff --git a/sklearn/utils/_openmp_helpers.pyx b/sklearn/utils/_openmp_helpers.pyx index fb8920074a84e..cddd77ac42746 100644 --- a/sklearn/utils/_openmp_helpers.pyx +++ b/sklearn/utils/_openmp_helpers.pyx @@ -6,7 +6,7 @@ IF SKLEARN_OPENMP_PARALLELISM_ENABLED: def _openmp_parallelism_enabled(): """Determines whether scikit-learn has been built with OpenMP - + It allows to retrieve at runtime the information gathered at compile time. """ # SKLEARN_OPENMP_PARALLELISM_ENABLED is resolved at compile time during @@ -22,7 +22,7 @@ cpdef _openmp_effective_n_threads(n_threads=None): - if the ``OMP_NUM_THREADS`` environment variable is set, return ``openmp.omp_get_max_threads()`` - otherwise, return the minimum between ``openmp.omp_get_max_threads()`` - and the number of cpus, taking cgroups quotas into account. Cgroups + and the number of cpus, taking cgroups quotas into account. Cgroups quotas can typically be set by tools such as Docker. The result of ``omp_get_max_threads`` can be influenced by environment variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. @@ -59,4 +59,13 @@ cpdef _openmp_effective_n_threads(n_threads=None): # OpenMP disabled at build-time => sequential mode return 1 - + +cdef inline int _openmp_thread_num() nogil: + """Return the number of the thread calling this function. + + If scikit-learn is built without OpenMP support, always return 0. 
+ """ + IF SKLEARN_OPENMP_PARALLELISM_ENABLED: + return openmp.omp_get_thread_num() + ELSE: + return 0 diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 7b57549e5886c..ac84bf058df8c 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -48,7 +48,12 @@ import joblib import sklearn -from sklearn.utils import IS_PYPY, _IS_32BIT, deprecated +from sklearn.utils import ( + IS_PYPY, + _IS_32BIT, + deprecated, + _in_unstable_openblas_configuration, +) from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import ( check_array, @@ -448,6 +453,10 @@ def set_random_state(estimator, random_state=0): os.environ.get("TRAVIS") == "true", reason="skip on travis" ) fails_if_pypy = pytest.mark.xfail(IS_PYPY, reason="not compatible with PyPy") + fails_if_unstable_openblas = pytest.mark.xfail( + _in_unstable_openblas_configuration(), + reason="OpenBLAS is unstable for this configuration", + ) skip_if_no_parallel = pytest.mark.skipif( not joblib.parallel.mp, reason="joblib is in serial mode" )