diff --git a/.gitignore b/.gitignore index f3d2dc08ca954..24f562af3df15 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,13 @@ sklearn/utils/_weight_vector.pxd sklearn/linear_model/_sag_fast.pyx sklearn/metrics/_dist_metrics.pyx sklearn/metrics/_dist_metrics.pxd +sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd +sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx +sklearn/metrics/_pairwise_distances_reduction/_base.pxd +sklearn/metrics/_pairwise_distances_reduction/_base.pyx +sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd +sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx +sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd +sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx +sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd +sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 4417987640f94..952d2867360a3 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -296,7 +296,11 @@ Changelog For instance :class:`sklearn.neighbors.NearestNeighbors.kneighbors` and :class:`sklearn.neighbors.NearestNeighbors.radius_neighbors` - can respectively be up to ×20 and ×5 faster than previously. + can respectively be up to ×20 and ×5 faster than previously on a laptop. + + Moreover, implementations of those two algorithms are now suitable + for machines with many cores, making them usable for datasets consisting + of millions of samples. :pr:`21987`, :pr:`22064`, :pr:`22065`, :pr:`22288` and :pr:`22320` by :user:`Julien Jerphanion <jjerphan>`. diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 18bb775a0b02e..750b9971c8801 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -73,6 +73,42 @@ Changelog parameter `base_estimator` is deprecated and will be removed in 1.4. :pr:`22054` by :user:`Kevin Roice <kevroi>`. 
+- |Efficiency| Low-level routines for reductions on pairwise distances + for dense float32 datasets have been refactored. The following functions + and estimators now benefit from improved performance in terms of hardware + scalability and speed-ups: + + - :func:`sklearn.metrics.pairwise_distances_argmin` + - :func:`sklearn.metrics.pairwise_distances_argmin_min` + - :class:`sklearn.cluster.AffinityPropagation` + - :class:`sklearn.cluster.Birch` + - :class:`sklearn.cluster.MeanShift` + - :class:`sklearn.cluster.OPTICS` + - :class:`sklearn.cluster.SpectralClustering` + - :func:`sklearn.feature_selection.mutual_info_regression` + - :class:`sklearn.neighbors.KNeighborsClassifier` + - :class:`sklearn.neighbors.KNeighborsRegressor` + - :class:`sklearn.neighbors.RadiusNeighborsClassifier` + - :class:`sklearn.neighbors.RadiusNeighborsRegressor` + - :class:`sklearn.neighbors.LocalOutlierFactor` + - :class:`sklearn.neighbors.NearestNeighbors` + - :class:`sklearn.manifold.Isomap` + - :class:`sklearn.manifold.LocallyLinearEmbedding` + - :class:`sklearn.manifold.TSNE` + - :func:`sklearn.manifold.trustworthiness` + - :class:`sklearn.semi_supervised.LabelPropagation` + - :class:`sklearn.semi_supervised.LabelSpreading` + + For instance :meth:`sklearn.neighbors.NearestNeighbors.kneighbors` and + :meth:`sklearn.neighbors.NearestNeighbors.radius_neighbors` + can respectively be up to ×20 and ×5 faster than previously on a laptop. + + Moreover, implementations of those two algorithms are now suitable + for machines with many cores, making them usable for datasets consisting + of millions of samples. + + :pr:`23865` by :user:`Julien Jerphanion <jjerphan>`. + :mod:`sklearn.cluster` ...................... 
diff --git a/setup.cfg b/setup.cfg index 21b9db50de2cb..35bc27c410976 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,17 @@ ignore = sklearn/utils/_weight_vector.pxd sklearn/metrics/_dist_metrics.pyx sklearn/metrics/_dist_metrics.pxd + sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd + sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx + sklearn/metrics/_pairwise_distances_reduction/_base.pxd + sklearn/metrics/_pairwise_distances_reduction/_base.pyx + sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd + sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx + sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd + sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx + sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd + sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx + [codespell] skip = ./.git,./.mypy_cache,./doc/themes/scikit-learn-modern/static/js,./sklearn/feature_extraction/_stop_words.py,./doc/_build,./doc/auto_examples,./doc/modules/generated diff --git a/sklearn/manifold/tests/test_t_sne.py b/sklearn/manifold/tests/test_t_sne.py index 882bb0c4f8023..38387142ccf50 100644 --- a/sklearn/manifold/tests/test_t_sne.py +++ b/sklearn/manifold/tests/test_t_sne.py @@ -1048,32 +1048,54 @@ def test_gradient_bh_multithread_match_sequential(): assert_allclose(grad_multithread, grad_multithread) -def test_tsne_with_different_distance_metrics(): +@pytest.mark.parametrize( + "metric, dist_func", + [("manhattan", manhattan_distances), ("cosine", cosine_distances)], +) +@pytest.mark.parametrize("method", ["barnes_hut", "exact"]) +def test_tsne_with_different_distance_metrics(metric, dist_func, method): """Make sure that TSNE works for different distance metrics""" + + if method == "barnes_hut" and metric == "manhattan": + # The distances computed by `manhattan_distances` differ slightly from those + # computed internally by NearestNeighbors via the 
PairwiseDistancesReduction + # Cython code-based. This in turns causes T-SNE to converge to a different + # solution but this should not impact the qualitative results as both + # methods. + # NOTE: it's probably not valid from a mathematical point of view to use the + # Manhattan distance for T-SNE... + # TODO: re-enable this test if/when `manhattan_distances` is refactored to + # reuse the same underlying Cython code NearestNeighbors. + # For reference, see: + # https://github.com/scikit-learn/scikit-learn/pull/23865/files#r925721573 + pytest.xfail( + "Distance computations are different for method == 'barnes_hut' and metric" + " == 'manhattan', but this is expected." + ) + random_state = check_random_state(0) n_components_original = 3 n_components_embedding = 2 X = random_state.randn(50, n_components_original).astype(np.float32) - metrics = ["manhattan", "cosine"] - dist_funcs = [manhattan_distances, cosine_distances] - for metric, dist_func in zip(metrics, dist_funcs): - X_transformed_tsne = TSNE( - metric=metric, - n_components=n_components_embedding, - random_state=0, - n_iter=300, - init="random", - learning_rate="auto", - ).fit_transform(X) - X_transformed_tsne_precomputed = TSNE( - metric="precomputed", - n_components=n_components_embedding, - random_state=0, - n_iter=300, - init="random", - learning_rate="auto", - ).fit_transform(dist_func(X)) - assert_array_equal(X_transformed_tsne, X_transformed_tsne_precomputed) + X_transformed_tsne = TSNE( + metric=metric, + method=method, + n_components=n_components_embedding, + random_state=0, + n_iter=300, + init="random", + learning_rate="auto", + ).fit_transform(X) + X_transformed_tsne_precomputed = TSNE( + metric="precomputed", + method=method, + n_components=n_components_embedding, + random_state=0, + n_iter=300, + init="random", + learning_rate="auto", + ).fit_transform(dist_func(X)) + assert_array_equal(X_transformed_tsne, X_transformed_tsne_precomputed) # TODO: Remove in 1.2 diff --git 
a/sklearn/metrics/_dist_metrics.pxd.tp b/sklearn/metrics/_dist_metrics.pxd.tp index 8f8aa21107015..8e972435b2951 100644 --- a/sklearn/metrics/_dist_metrics.pxd.tp +++ b/sklearn/metrics/_dist_metrics.pxd.tp @@ -72,7 +72,7 @@ cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{INPUT_DTYPE_t ###################################################################### -# DistanceMetric base class +# DistanceMetric{{name_suffix}} base class cdef class DistanceMetric{{name_suffix}}: # The following attributes are required for a few of the subclasses. # we must define them here so that cython's limited polymorphism will work. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd deleted file mode 100644 index 34d3339e1c9e0..0000000000000 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd +++ /dev/null @@ -1,33 +0,0 @@ -cimport numpy as cnp - -from ._base cimport ( - PairwiseDistancesReduction64, -) -from ._gemm_term_computer cimport GEMMTermComputer64 - -from ...utils._typedefs cimport ITYPE_t, DTYPE_t - -cnp.import_array() - -cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" - - cdef: - ITYPE_t k - - ITYPE_t[:, ::1] argkmin_indices - DTYPE_t[:, ::1] argkmin_distances - - # Used as array of pointers to private datastructures used in threads. 
- DTYPE_t ** heaps_r_distances_chunks - ITYPE_t ** heaps_indices_chunks - - -cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" - cdef: - GEMMTermComputer64 gemm_term_computer - const DTYPE_t[::1] X_norm_squared - const DTYPE_t[::1] Y_norm_squared - - bint use_squared_distances diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp new file mode 100644 index 0000000000000..4a7d4db953391 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp @@ -0,0 +1,50 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', 'cnp.float64_t', 'np.float64'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} + +cimport numpy as cnp +from ...utils._typedefs cimport ITYPE_t, DTYPE_t + +cnp.import_array() + +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + +from ._base cimport PairwiseDistancesReduction{{name_suffix}} +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} + +cdef class PairwiseDistancesArgKmin{{name_suffix}}(PairwiseDistancesReduction{{name_suffix}}): + """{{name_suffix}}bit implementation of PairwiseDistancesArgKmin.""" + + cdef: + ITYPE_t k + + ITYPE_t[:, ::1] argkmin_indices + DTYPE_t[:, ::1] argkmin_distances + + # Used as array of pointers to private datastructures used in threads. 
+ DTYPE_t ** heaps_r_distances_chunks + ITYPE_t ** heaps_indices_chunks + + +cdef class FastEuclideanPairwiseDistancesArgKmin{{name_suffix}}(PairwiseDistancesArgKmin{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistancesArgKmin.""" + cdef: + GEMMTermComputer{{name_suffix}} gemm_term_computer + const DTYPE_t[::1] X_norm_squared + const DTYPE_t[::1] Y_norm_squared + + bint use_squared_distances + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp similarity index 82% rename from sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index 2f378543e1f97..8cdd7e1687e8c 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -1,3 +1,19 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. 
+ # + ('64', 'DTYPE_t', 'DTYPE'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} + cimport numpy as cnp from libc.stdlib cimport free, malloc @@ -5,18 +21,6 @@ from libc.float cimport DBL_MAX from cython cimport final from cython.parallel cimport parallel, prange -from ._base cimport ( - PairwiseDistancesReduction64, - _sqeuclidean_row_norms64, -) - -from ._datasets_pair cimport ( - DatasetsPair, - DenseDenseDatasetsPair, -) - -from ._gemm_term_computer cimport GEMMTermComputer64 - from ...utils._heap cimport heap_push from ...utils._sorting cimport simultaneous_sort from ...utils._typedefs cimport ITYPE_t, DTYPE_t @@ -26,15 +30,30 @@ import warnings from numbers import Integral from scipy.sparse import issparse -from sklearn.utils import check_scalar, _in_unstable_openblas_configuration -from sklearn.utils.fixes import threadpool_limits +from ...utils import check_array, check_scalar, _in_unstable_openblas_configuration +from ...utils.fixes import threadpool_limits from ...utils._typedefs import ITYPE, DTYPE + cnp.import_array() +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesArgKmin.""" +from ._base cimport ( + PairwiseDistancesReduction{{name_suffix}}, + _sqeuclidean_row_norms{{name_suffix}}, +) + +from ._datasets_pair cimport ( + DatasetsPair{{name_suffix}}, + DenseDenseDatasetsPair{{name_suffix}}, +) + +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} + + +cdef class PairwiseDistancesArgKmin{{name_suffix}}(PairwiseDistancesReduction{{name_suffix}}): + """{{name_suffix}}bit implementation of PairwiseDistancesArgKmin.""" @classmethod def compute( @@ -52,7 +71,7 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): This classmethod is responsible for introspecting the arguments values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesArgKmin64`. 
+ :class:`PairwiseDistancesArgKmin{{name_suffix}}`. This allows decoupling the API entirely from the implementation details whilst maintaining RAII: all temporarily allocated datastructures necessary @@ -71,7 +90,7 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): # at time to leverage a call to the BLAS GEMM routine as explained # in more details in the docstring. use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesArgKmin64( + pda = FastEuclideanPairwiseDistancesArgKmin{{name_suffix}}( X=X, Y=Y, k=k, use_squared_distances=use_squared_distances, chunk_size=chunk_size, @@ -81,8 +100,8 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): else: # Fall back on a generic implementation that handles most scipy # metrics by computing the distances between 2 vectors at a time. - pda = PairwiseDistancesArgKmin64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + pda = PairwiseDistancesArgKmin{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), k=k, chunk_size=chunk_size, strategy=strategy, @@ -100,7 +119,7 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): def __init__( self, - DatasetsPair datasets_pair, + DatasetsPair{{name_suffix}} datasets_pair, chunk_size=None, strategy=None, ITYPE_t k=1, @@ -128,7 +147,8 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): sizeof(ITYPE_t *) * self.chunks_n_threads ) - # Main heaps which will be returned as results by `PairwiseDistancesArgKmin64.compute`. + # Main heaps which will be returned as results by + # `PairwiseDistancesArgKmin{{name_suffix}}.compute`. 
self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=ITYPE) self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=DTYPE) @@ -302,18 +322,19 @@ cdef class PairwiseDistancesArgKmin64(PairwiseDistancesReduction64): # Values are returned identically to the way `KNeighborsMixin.kneighbors` # returns values. This is counter-intuitive but this allows not using - # complex adaptations where `PairwiseDistancesArgKmin64.compute` is called. + # complex adaptations where + # `PairwiseDistancesArgKmin{{name_suffix}}.compute` is called. return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) return np.asarray(self.argkmin_indices) -cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesArgKmin.""" +cdef class FastEuclideanPairwiseDistancesArgKmin{{name_suffix}}(PairwiseDistancesArgKmin{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}} bit implementation for PairwiseDistancesArgKmin.""" @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesArgKmin64.is_usable_for(X, Y, metric) and + return (PairwiseDistancesArgKmin{{name_suffix}}.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) def __init__( @@ -340,19 +361,20 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): super().__init__( # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), chunk_size=chunk_size, strategy=strategy, k=k, ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + # X and Y are checked by the DatasetsPair{{name_suffix}} implemented + # as a DenseDenseDatasetsPair{{name_suffix}} cdef: - DenseDenseDatasetsPair datasets_pair = ( - self.datasets_pair + 
DenseDenseDatasetsPair{{name_suffix}} datasets_pair = ( + self.datasets_pair ) ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - self.gemm_term_computer = GEMMTermComputer64( + self.gemm_term_computer = GEMMTermComputer{{name_suffix}}( datasets_pair.X, datasets_pair.Y, self.effective_n_threads, @@ -363,28 +385,33 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + self.Y_norm_squared = check_array( + metric_kwargs.pop("Y_norm_squared"), + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64 + ) else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: - PairwiseDistancesArgKmin64.compute_exact_distances(self) + PairwiseDistancesArgKmin{{name_suffix}}.compute_exact_distances(self) @final cdef void _parallel_on_X_parallel_init( self, ITYPE_t thread_num, ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_parallel_init(self, thread_num) + PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) @@ -395,7 +422,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ITYPE_t X_start, ITYPE_t X_end, ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) 
+ PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) @@ -408,7 +435,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistancesArgKmin64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -424,7 +451,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): self, ) nogil: cdef ITYPE_t thread_num - PairwiseDistancesArgKmin64._parallel_on_Y_init(self) + PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_Y_init(self) self.gemm_term_computer._parallel_on_Y_init() @@ -435,7 +462,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ITYPE_t X_start, ITYPE_t X_end, ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) @@ -448,7 +475,7 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistancesArgKmin64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistancesArgKmin{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -499,3 +526,5 @@ cdef class FastEuclideanPairwiseDistancesArgKmin64(PairwiseDistancesArgKmin64): ), j + Y_start, ) + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp similarity index 78% rename from sklearn/metrics/_pairwise_distances_reduction/_base.pxd 
rename to sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp index 9f6ad45cb839a..d023058df828a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp @@ -1,23 +1,40 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', 'DTYPE_t', 'DTYPE'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} cimport numpy as cnp from cython cimport final -from ._datasets_pair cimport DatasetsPair from ...utils._typedefs cimport ITYPE_t, DTYPE_t cnp.import_array() +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( - const DTYPE_t[:, ::1] X, +from ._datasets_pair cimport DatasetsPair{{name_suffix}} + +cpdef DTYPE_t[::1] _sqeuclidean_row_norms{{name_suffix}}( + const {{INPUT_DTYPE_t}}[:, ::1] X, ITYPE_t num_threads, ) -cdef class PairwiseDistancesReduction64: - """Base 64bit implementation of PairwiseDistancesReduction.""" +cdef class PairwiseDistancesReduction{{name_suffix}}: + """Base {{name_suffix}}bit implementation of PairwiseDistancesReduction.""" cdef: - readonly DatasetsPair datasets_pair + readonly DatasetsPair{{name_suffix}} datasets_pair # The number of threads that can be used is stored in effective_n_threads. 
# @@ -126,3 +143,4 @@ cdef class PairwiseDistancesReduction64: cdef void _parallel_on_Y_finalize( self, ) nogil +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp similarity index 85% rename from sklearn/metrics/_pairwise_distances_reduction/_base.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp index 07506e3616a74..6ff9eb971013a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_base.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp @@ -1,16 +1,34 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', 'DTYPE_t', 'DTYPE'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} cimport numpy as cnp -import numpy as np +cimport openmp -from sklearn import get_config from cython cimport final +from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange +from libcpp.vector cimport vector -from ._datasets_pair cimport DatasetsPair from ...utils._cython_blas cimport _dot from ...utils._openmp_helpers cimport _openmp_thread_num from ...utils._typedefs cimport ITYPE_t, DTYPE_t +import numpy as np + from numbers import Integral +from sklearn import get_config from sklearn.utils import check_scalar from ...utils._openmp_helpers import _openmp_effective_n_threads from ...utils._typedefs import ITYPE, DTYPE @@ -44,12 +62,55 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms64( return squared_row_norms -cdef class PairwiseDistancesReduction64: - """Base 64bit implementation of PairwiseDistancesReduction.""" +cpdef DTYPE_t[::1] _sqeuclidean_row_norms32( + const cnp.float32_t[:, ::1] X, + ITYPE_t num_threads, +): + """Compute the squared euclidean norm of the rows of X in parallel. 
+ + This is faster than using np.einsum("ij, ij->i") even when using a single thread. + """ + cdef: + # Casting for X to remove the const qualifier is needed because APIs + # exposed via scipy.linalg.cython_blas aren't reflecting the arguments' + # const qualifier. + # See: https://github.com/scipy/scipy/issues/14262 + cnp.float32_t * X_ptr = &X[0, 0] + ITYPE_t i = 0, j = 0 + ITYPE_t thread_num + ITYPE_t n = X.shape[0] + ITYPE_t d = X.shape[1] + DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE) + + # To upcast the i-th row of X from 32bit to 64bit + vector[vector[DTYPE_t]] X_i_upcast = vector[vector[DTYPE_t]]( + num_threads, vector[DTYPE_t](d) + ) + + with nogil, parallel(num_threads=num_threads): + thread_num = openmp.omp_get_thread_num() + for i in prange(n, schedule='static'): + # Upcasting the i-th row of X from 32bit to 64bit + for j in range(d): + X_i_upcast[thread_num][j] = deref(X_ptr + i * d + j) + + squared_row_norms[i] = _dot( + d, X_i_upcast[thread_num].data(), 1, + X_i_upcast[thread_num].data(), 1, + ) + + return squared_row_norms + +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + +from ._datasets_pair cimport DatasetsPair{{name_suffix}} + +cdef class PairwiseDistancesReduction{{name_suffix}}: + """Base {{name_suffix}}bit implementation of PairwiseDistancesReduction.""" def __init__( self, - DatasetsPair datasets_pair, + DatasetsPair{{name_suffix}} datasets_pair, chunk_size=None, strategy=None, ): @@ -263,7 +324,7 @@ cdef class PairwiseDistancesReduction64: ) nogil: """Compute the pairwise distances on two chunks of X and Y and reduce them. - This is THE core computational method of PairwiseDistanceReductions64. + This is THE core computational method of PairwiseDistanceReductions{{name_suffix}}. This must be implemented in subclasses agnostically from the parallelization strategies. 
""" @@ -370,3 +431,5 @@ cdef class PairwiseDistancesReduction64: ) nogil: """Update datastructures after executing all the reductions.""" return + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd deleted file mode 100644 index de6458f8c6f26..0000000000000 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd +++ /dev/null @@ -1,21 +0,0 @@ -from ...utils._typedefs cimport DTYPE_t, ITYPE_t -from ...metrics._dist_metrics cimport DistanceMetric - - -cdef class DatasetsPair: - cdef DistanceMetric distance_metric - - cdef ITYPE_t n_samples_X(self) nogil - - cdef ITYPE_t n_samples_Y(self) nogil - - cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil - - cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil - - -cdef class DenseDenseDatasetsPair(DatasetsPair): - cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y - ITYPE_t d diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp new file mode 100644 index 0000000000000..d10f8f493e5f0 --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -0,0 +1,41 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. 
+ # + ('64', 'DistanceMetric', 'DTYPE_t', 'DTYPE'), + ('32', 'DistanceMetric32', 'cnp.float32_t', 'np.float32') +] + +}} +cimport numpy as cnp + +from ...utils._typedefs cimport DTYPE_t, ITYPE_t +from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32 + +{{for name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + + +cdef class DatasetsPair{{name_suffix}}: + cdef {{DistanceMetric}} distance_metric + + cdef ITYPE_t n_samples_X(self) nogil + + cdef ITYPE_t n_samples_Y(self) nogil + + cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil + + cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil + + +cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): + cdef: + const {{INPUT_DTYPE_t}}[:, ::1] X + const {{INPUT_DTYPE_t}}[:, ::1] Y + ITYPE_t d +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp similarity index 78% rename from sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index abef1bed098ed..05364c3ce52b0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -1,3 +1,18 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. 
+ # + ('64', 'DistanceMetric', 'DTYPE_t', 'DTYPE'), + ('32', 'DistanceMetric32', 'cnp.float32_t', 'np.float32') +] + +}} import numpy as np cimport numpy as cnp @@ -7,9 +22,9 @@ from scipy.sparse import issparse from ...utils._typedefs cimport DTYPE_t, ITYPE_t from ...metrics._dist_metrics cimport DistanceMetric -cnp.import_array() +{{for name_suffix, DistanceMetric, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -cdef class DatasetsPair: +cdef class DatasetsPair{{name_suffix}}: """Abstract class which wraps a pair of datasets (X, Y). This class allows computing distances between a single pair of rows of @@ -18,9 +33,9 @@ cdef class DatasetsPair: The handling of parallelization over chunks to compute the distances and aggregation for several rows at a time is done in dedicated - subclasses of PairwiseDistancesReduction that in-turn rely on - subclasses of DatasetsPair for each pair of rows in the data. The goal - is to make it possible to decouple the generic parallelization and + subclasses of PairwiseDistancesReduction{{name_suffix}} that in-turn rely on + subclasses of DatasetsPair{{name_suffix}} for each pair of rows in the data. + The goal is to make it possible to decouple the generic parallelization and aggregation logic from metric-specific computation as much as possible. @@ -34,7 +49,7 @@ cdef class DatasetsPair: Parameters ---------- - distance_metric: DistanceMetric + distance_metric: DistanceMetric{{name_suffix}} The distance metric responsible for computing distances between two vectors of (X, Y). """ @@ -46,7 +61,7 @@ cdef class DatasetsPair: Y, str metric="euclidean", dict metric_kwargs=None, - ) -> DatasetsPair: + ) -> DatasetsPair{{name_suffix}}: """Return the DatasetsPair implementation for the given arguments. Parameters @@ -72,21 +87,15 @@ cdef class DatasetsPair: Returns ------- - datasets_pair: DatasetsPair - The suited DatasetsPair implementation. 
+ datasets_pair: DatasetsPair{{name_suffix}} + The suited DatasetsPair{{name_suffix}} implementation. """ cdef: - DistanceMetric distance_metric = DistanceMetric.get_metric( + {{DistanceMetric}} distance_metric = {{DistanceMetric}}.get_metric( metric, **(metric_kwargs or {}) ) - if not(X.dtype == Y.dtype == np.float64): - raise ValueError( - f"Only 64bit float datasets are supported at this time, " - f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." - ) - # Metric-specific checks that do not replace nor duplicate `check_array`. distance_metric._validate_data(X) distance_metric._validate_data(Y) @@ -95,9 +104,9 @@ cdef class DatasetsPair: if issparse(X) or issparse(Y): raise ValueError("Only dense datasets are supported for X and Y.") - return DenseDenseDatasetsPair(X, Y, distance_metric) + return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric) - def __init__(self, DistanceMetric distance_metric): + def __init__(self, {{DistanceMetric}} distance_metric): self.distance_metric = distance_metric cdef ITYPE_t n_samples_X(self) nogil: @@ -124,7 +133,7 @@ cdef class DatasetsPair: return -1 @final -cdef class DenseDenseDatasetsPair(DatasetsPair): +cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): """Compute distances between row vectors of two arrays. Parameters @@ -140,7 +149,7 @@ cdef class DenseDenseDatasetsPair(DatasetsPair): between two row vectors of (X, Y). 
""" - def __init__(self, X, Y, DistanceMetric distance_metric): + def __init__(self, X, Y, {{DistanceMetric}} distance_metric): super().__init__(distance_metric) # Arrays have already been checked self.X = X @@ -162,3 +171,4 @@ cdef class DenseDenseDatasetsPair(DatasetsPair): @final cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil: return self.distance_metric.dist(&self.X[i, 0], &self.Y[j, 0], self.d) +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 1cf670ed35dec..c49c42cff61c9 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -6,9 +6,18 @@ from scipy.sparse import issparse from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING -from ._base import _sqeuclidean_row_norms64 -from ._argkmin import PairwiseDistancesArgKmin64 -from ._radius_neighborhood import PairwiseDistancesRadiusNeighborhood64 +from ._base import ( + _sqeuclidean_row_norms64, + _sqeuclidean_row_norms32, +) +from ._argkmin import ( + PairwiseDistancesArgKmin64, + PairwiseDistancesArgKmin32, +) +from ._radius_neighborhood import ( + PairwiseDistancesRadiusNeighborhood64, + PairwiseDistancesRadiusNeighborhood32, +) from ... import get_config @@ -31,8 +40,12 @@ def sqeuclidean_row_norms(X, num_threads): """ if X.dtype == np.float64: return _sqeuclidean_row_norms64(X, num_threads) + if X.dtype == np.float32: + return _sqeuclidean_row_norms32(X, num_threads) + raise ValueError( - f"Only 64bit float datasets are supported at this time, got: X.dtype={X.dtype}." + "Only float64 or float32 datasets are supported at this time, " + f"got: X.dtype={X.dtype}." ) @@ -79,7 +92,7 @@ def is_usable_for(cls, X, Y, metric) -> bool: ------- True if the PairwiseDistancesReduction can be used, else False. 
""" - dtypes_validity = X.dtype == Y.dtype == np.float64 + dtypes_validity = X.dtype == Y.dtype and X.dtype in (np.float32, np.float64) c_contiguity = ( hasattr(X, "flags") and X.flags.c_contiguous @@ -247,8 +260,21 @@ def compute( strategy=strategy, return_distance=return_distance, ) + + if X.dtype == Y.dtype == np.float32: + return PairwiseDistancesArgKmin32.compute( + X=X, + Y=Y, + k=k, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + return_distance=return_distance, + ) + raise ValueError( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." ) @@ -386,7 +412,21 @@ def compute( sort_results=sort_results, return_distance=return_distance, ) + + if X.dtype == Y.dtype == np.float32: + return PairwiseDistancesRadiusNeighborhood32.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + sort_results=sort_results, + return_distance=return_distance, + ) + raise ValueError( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." 
) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd.tp similarity index 63% rename from sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd rename to sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd.tp index a1c5bd3a8d80c..5978cfee9ebee 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd.tp @@ -1,11 +1,29 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', False, 'DTYPE_t', 'DTYPE'), + ('32', True, 'cnp.float32_t', 'np.float32') +] + +}} +cimport numpy as cnp + from ...utils._typedefs cimport DTYPE_t, ITYPE_t from libcpp.vector cimport vector +{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -cdef class GEMMTermComputer64: +cdef class GEMMTermComputer{{name_suffix}}: cdef: - const DTYPE_t[:, ::1] X - const DTYPE_t[:, ::1] Y + const {{INPUT_DTYPE_t}}[:, ::1] X + const {{INPUT_DTYPE_t}}[:, ::1] Y ITYPE_t effective_n_threads ITYPE_t chunks_n_threads @@ -16,6 +34,12 @@ cdef class GEMMTermComputer64: # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM vector[vector[DTYPE_t]] dist_middle_terms_chunks +{{if upcast_to_float64}} + # Buffers for upcasting chunks of X and Y from 32bit to 64bit + vector[vector[DTYPE_t]] X_c_upcast + vector[vector[DTYPE_t]] Y_c_upcast +{{endif}} + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, ITYPE_t X_start, @@ -60,3 +84,5 @@ cdef class GEMMTermComputer64: ITYPE_t Y_end, ITYPE_t thread_num, ) nogil + +{{endfor}} diff --git 
a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp similarity index 57% rename from sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp index 77d752548bb5b..35e57219a96a7 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp @@ -1,3 +1,20 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', False, 'DTYPE_t', 'DTYPE'), + ('32', True, 'cnp.float32_t', 'np.float32') +] + +}} +cimport numpy as cnp + from libcpp.vector cimport vector from ...utils._typedefs cimport DTYPE_t, ITYPE_t @@ -12,7 +29,9 @@ from ...utils._cython_blas cimport ( _gemm, ) -cdef class GEMMTermComputer64: +{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + +cdef class GEMMTermComputer{{name_suffix}}: """Component for `FastEuclidean*` variant wrapping the logic for the call to GEMM. `FastEuclidean*` classes internally compute the squared Euclidean distances between @@ -28,8 +47,8 @@ cdef class GEMMTermComputer64: """ def __init__(self, - const DTYPE_t[:, ::1] X, - const DTYPE_t[:, ::1] Y, + const {{INPUT_DTYPE_t}}[:, ::1] X, + const {{INPUT_DTYPE_t}}[:, ::1] Y, ITYPE_t effective_n_threads, ITYPE_t chunks_n_threads, ITYPE_t dist_middle_terms_chunks_size, @@ -46,6 +65,17 @@ cdef class GEMMTermComputer64: self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads) +{{if upcast_to_float64}} + # We populate the buffer for upcasting chunks of X and Y from 32bit to 64bit. 
+ self.X_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads) + self.Y_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads) + + upcast_buffer_n_elements = self.chunk_size * n_features + + for thread_num in range(self.effective_n_threads): + self.X_c_upcast[thread_num].resize(upcast_buffer_n_elements) + self.Y_c_upcast[thread_num].resize(upcast_buffer_n_elements) +{{endif}} cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, @@ -55,7 +85,19 @@ cdef class GEMMTermComputer64: ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: +{{if upcast_to_float64}} + cdef: + ITYPE_t i, j + ITYPE_t n_chunk_samples = Y_end - Y_start + + # Upcasting Y_c=Y[Y_start:Y_end, :] from float32 to float64 + for i in range(n_chunk_samples): + for j in range(self.n_features): + self.Y_c_upcast[thread_num][i * self.n_features + j] = self.Y[Y_start + i, j] +{{else}} return +{{endif}} + cdef void _parallel_on_X_parallel_init(self, ITYPE_t thread_num) nogil: self.dist_middle_terms_chunks[thread_num].resize(self.dist_middle_terms_chunks_size) @@ -66,7 +108,18 @@ cdef class GEMMTermComputer64: ITYPE_t X_start, ITYPE_t X_end, ) nogil: +{{if upcast_to_float64}} + cdef: + ITYPE_t i, j + ITYPE_t n_chunk_samples = X_end - X_start + + # Upcasting X_c=X[X_start:X_end, :] from float32 to float64 + for i in range(n_chunk_samples): + for j in range(self.n_features): + self.X_c_upcast[thread_num][i * self.n_features + j] = self.X[X_start + i, j] +{{else}} return +{{endif}} cdef void _parallel_on_Y_init(self) nogil: for thread_num in range(self.chunks_n_threads): @@ -80,7 +133,18 @@ cdef class GEMMTermComputer64: ITYPE_t X_start, ITYPE_t X_end, ) nogil: +{{if upcast_to_float64}} + cdef: + ITYPE_t i, j + ITYPE_t n_chunk_samples = X_end - X_start + + # Upcasting X_c=X[X_start:X_end, :] from float32 to float64 + for i in range(n_chunk_samples): + for j in range(self.n_features): + self.X_c_upcast[thread_num][i * self.n_features + j] = self.X[X_start + i, j] +{{else}} return +{{endif}} 
cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, @@ -90,7 +154,18 @@ cdef class GEMMTermComputer64: ITYPE_t Y_end, ITYPE_t thread_num ) nogil: +{{if upcast_to_float64}} + cdef: + ITYPE_t i, j + ITYPE_t n_chunk_samples = Y_end - Y_start + + # Upcasting Y_c=Y[Y_start:Y_end, :] from float32 to float64 + for i in range(n_chunk_samples): + for j in range(self.n_features): + self.Y_c_upcast[thread_num][i * self.n_features + j] = self.Y[Y_start + i, j] +{{else}} return +{{endif}} cdef DTYPE_t * _compute_distances_on_chunks( self, @@ -103,8 +178,8 @@ cdef class GEMMTermComputer64: cdef: ITYPE_t i, j DTYPE_t squared_dist_i_j - const DTYPE_t[:, ::1] X_c = self.X[X_start:X_end, :] - const DTYPE_t[:, ::1] Y_c = self.Y[Y_start:Y_end, :] + const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :] + const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :] DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data() # Careful: LDA, LDB and LDC are given for F-ordered arrays @@ -119,11 +194,16 @@ cdef class GEMMTermComputer64: ITYPE_t n = Y_c.shape[0] ITYPE_t K = X_c.shape[1] DTYPE_t alpha = - 2. +{{if upcast_to_float64}} + DTYPE_t * A = self.X_c_upcast[thread_num].data() + DTYPE_t * B = self.Y_c_upcast[thread_num].data() +{{else}} # Casting for A and B to remove the const is needed because APIs exposed via # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier. # See: https://github.com/scipy/scipy/issues/14262 DTYPE_t * A = &X_c[0, 0] DTYPE_t * B = &Y_c[0, 0] +{{endif}} ITYPE_t lda = X_c.shape[1] ITYPE_t ldb = X_c.shape[1] DTYPE_t beta = 0. 
@@ -133,3 +213,5 @@ cdef class GEMMTermComputer64: _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc) return dist_middle_terms + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd.tp similarity index 68% rename from sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd rename to sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd.tp index 737e6888a8a55..639b48c5b64ae 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd.tp @@ -1,14 +1,24 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. + # + ('64', 'DTYPE_t', 'DTYPE'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} cimport numpy as cnp from libcpp.memory cimport shared_ptr from libcpp.vector cimport vector from cython cimport final -from ._base cimport ( - PairwiseDistancesReduction64, -) -from ._gemm_term_computer cimport GEMMTermComputer64 - from ...utils._typedefs cimport ITYPE_t, DTYPE_t cnp.import_array() @@ -31,14 +41,18 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( ) ##################### +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesRadiusNeighborhood .""" +from ._base cimport PairwiseDistancesReduction{{name_suffix}} +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} + +cdef class PairwiseDistancesRadiusNeighborhood{{name_suffix}}(PairwiseDistancesReduction{{name_suffix}}): + """{{name_suffix}}bit implementation of 
PairwiseDistancesRadiusNeighborhood .""" cdef: DTYPE_t radius - # DistanceMetric compute rank-preserving surrogate distance via rdist + # DistanceMetric{{name_suffix}} compute rank-preserving surrogate distance via rdist # which are proxies necessitating less computations. # We get the equivalent for the radius to be able to compare it against # vectors' rank-preserving surrogate distances. @@ -79,11 +93,13 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): ) nogil -cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood{{name_suffix}}(PairwiseDistancesRadiusNeighborhood{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistancesRadiusNeighborhood.""" cdef: - GEMMTermComputer64 gemm_term_computer + GEMMTermComputer{{name_suffix}} gemm_term_computer const DTYPE_t[::1] X_norm_squared const DTYPE_t[::1] Y_norm_squared bint use_squared_distances + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx.tp similarity index 83% rename from sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx rename to sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx.tp index db2c22e89d06d..a5890fd80c5aa 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx.tp @@ -1,3 +1,18 @@ +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + # + # We also use the float64 dtype and C-type names as defined in + # `sklearn.utils._typedefs` to maintain consistency. 
+ # + ('64', 'DTYPE_t', 'DTYPE'), + ('32', 'cnp.float32_t', 'np.float32') +] + +}} cimport numpy as cnp import numpy as np import warnings @@ -8,26 +23,14 @@ from cython cimport final from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange -from ._base cimport ( - PairwiseDistancesReduction64, - _sqeuclidean_row_norms64 -) - -from ._datasets_pair cimport ( - DatasetsPair, - DenseDenseDatasetsPair, -) - -from ._gemm_term_computer cimport GEMMTermComputer64 - from ...utils._sorting cimport simultaneous_sort from ...utils._typedefs cimport ITYPE_t, DTYPE_t from ...utils._vector_sentinel cimport vector_to_nd_array from numbers import Real from scipy.sparse import issparse -from sklearn.utils import check_scalar, _in_unstable_openblas_configuration -from sklearn.utils.fixes import threadpool_limits +from ...utils import check_array, check_scalar, _in_unstable_openblas_configuration +from ...utils.fixes import threadpool_limits cnp.import_array() @@ -53,10 +56,23 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays( return nd_arrays_of_nd_arrays ##################### +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} +from ._base cimport ( + PairwiseDistancesReduction{{name_suffix}}, + _sqeuclidean_row_norms{{name_suffix}} +) + +from ._datasets_pair cimport ( + DatasetsPair{{name_suffix}}, + DenseDenseDatasetsPair{{name_suffix}}, +) + +from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}} -cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): - """64bit implementation of PairwiseDistancesRadiusNeighborhood.""" + +cdef class PairwiseDistancesRadiusNeighborhood{{name_suffix}}(PairwiseDistancesReduction{{name_suffix}}): + """{{name_suffix}}bit implementation of PairwiseDistancesRadiusNeighborhood.""" @classmethod def compute( @@ -75,7 +91,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): This classmethod is responsible 
for introspecting the arguments values to dispatch to the most appropriate implementation of - :class:`PairwiseDistancesRadiusNeighborhood64`. + :class:`PairwiseDistancesRadiusNeighborhood{{name_suffix}}`. This allows decoupling the API entirely from the implementation details whilst maintaining RAII: all temporarily allocated datastructures necessary @@ -94,7 +110,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): # at time to leverage a call to the BLAS GEMM routine as explained # in more details in GEMMTermComputer docstring. use_squared_distances = metric == "sqeuclidean" - pda = FastEuclideanPairwiseDistancesRadiusNeighborhood64( + pda = FastEuclideanPairwiseDistancesRadiusNeighborhood{{name_suffix}}( X=X, Y=Y, radius=radius, use_squared_distances=use_squared_distances, chunk_size=chunk_size, @@ -105,8 +121,8 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): else: # Fall back on a generic implementation that handles most scipy # metrics by computing the distances between 2 vectors at a time. 
- pda = PairwiseDistancesRadiusNeighborhood64( - datasets_pair=DatasetsPair.get_for(X, Y, metric, metric_kwargs), + pda = PairwiseDistancesRadiusNeighborhood{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), radius=radius, chunk_size=chunk_size, metric_kwargs=metric_kwargs, @@ -127,7 +143,7 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): def __init__( self, - DatasetsPair datasets_pair, + DatasetsPair{{name_suffix}} datasets_pair, DTYPE_t radius, chunk_size=None, strategy=None, @@ -317,12 +333,12 @@ cdef class PairwiseDistancesRadiusNeighborhood64(PairwiseDistancesReduction64): ) -cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesRadiusNeighborhood64): - """EuclideanDistance-specialized 64bit implementation for PairwiseDistancesRadiusNeighborhood.""" +cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood{{name_suffix}}(PairwiseDistancesRadiusNeighborhood{{name_suffix}}): + """EuclideanDistance-specialized {{name_suffix}}bit implementation for PairwiseDistancesRadiusNeighborhood.""" @classmethod def is_usable_for(cls, X, Y, metric) -> bool: - return (PairwiseDistancesRadiusNeighborhood64.is_usable_for(X, Y, metric) + return (PairwiseDistancesRadiusNeighborhood{{name_suffix}}.is_usable_for(X, Y, metric) and not _in_unstable_openblas_configuration()) def __init__( @@ -350,19 +366,22 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR super().__init__( # The datasets pair here is used for exact distances computations - datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"), + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"), radius=radius, chunk_size=chunk_size, strategy=strategy, sort_results=sort_results, metric_kwargs=metric_kwargs, ) - # X and Y are checked by the DatasetsPair implemented as a DenseDenseDatasetsPair + # X and Y are checked by the DatasetsPair{{name_suffix}} implemented 
+ # as a DenseDenseDatasetsPair{{name_suffix}} cdef: - DenseDenseDatasetsPair datasets_pair = self.datasets_pair + DenseDenseDatasetsPair{{name_suffix}} datasets_pair = ( + self.datasets_pair + ) ITYPE_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk - self.gemm_term_computer = GEMMTermComputer64( + self.gemm_term_computer = GEMMTermComputer{{name_suffix}}( datasets_pair.X, datasets_pair.Y, self.effective_n_threads, @@ -373,14 +392,19 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR ) if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs: - self.Y_norm_squared = metric_kwargs.pop("Y_norm_squared") + self.Y_norm_squared = check_array( + metric_kwargs.pop("Y_norm_squared"), + ensure_2d=False, + input_name="Y_norm_squared", + dtype=np.float64 + ) else: - self.Y_norm_squared = _sqeuclidean_row_norms64(datasets_pair.Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.Y, self.effective_n_threads) # Do not recompute norms if datasets are identical. 
self.X_norm_squared = ( self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms64(datasets_pair.X, self.effective_n_threads) + _sqeuclidean_row_norms{{name_suffix}}(datasets_pair.X, self.effective_n_threads) ) self.use_squared_distances = use_squared_distances @@ -394,7 +418,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR self, ITYPE_t thread_num, ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_parallel_init(self, thread_num) + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num) self.gemm_term_computer._parallel_on_X_parallel_init(thread_num) @final @@ -404,7 +428,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR ITYPE_t X_start, ITYPE_t X_end, ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end) @final @@ -416,7 +440,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -431,7 +455,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR self, ) nogil: cdef ITYPE_t thread_num - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_init(self) + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_Y_init(self) self.gemm_term_computer._parallel_on_Y_init() @final @@ -441,7 +465,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR ITYPE_t X_start, ITYPE_t X_end, ) nogil: - 
PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end) self.gemm_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end) @final @@ -453,7 +477,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR ITYPE_t Y_end, ITYPE_t thread_num, ) nogil: - PairwiseDistancesRadiusNeighborhood64._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + PairwiseDistancesRadiusNeighborhood{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, X_start, X_end, Y_start, Y_end, @@ -466,7 +490,7 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR @final cdef void compute_exact_distances(self) nogil: if not self.use_squared_distances: - PairwiseDistancesRadiusNeighborhood64.compute_exact_distances(self) + PairwiseDistancesRadiusNeighborhood{{name_suffix}}.compute_exact_distances(self) @final cdef void _compute_and_reduce_distances_on_chunks( @@ -501,3 +525,5 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood64(PairwiseDistancesR if squared_dist_i_j <= self.r_radius: deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) + +{{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/setup.py b/sklearn/metrics/_pairwise_distances_reduction/setup.py index 0d8c2c8ce33de..f55ec659b5821 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/setup.py +++ b/sklearn/metrics/_pairwise_distances_reduction/setup.py @@ -3,6 +3,8 @@ import numpy as np from numpy.distutils.misc_util import Configuration +from sklearn._build_utils import gen_from_templates + def configuration(parent_package="", top_path=None): config = Configuration("_pairwise_distances_reduction", parent_package, top_path) @@ -10,6 
+12,21 @@ def configuration(parent_package="", top_path=None): if os.name == "posix": libraries.append("m") + templates = [ + "sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp", + "sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd.tp", + "sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp", + "sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp", + "sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx.tp", + "sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd.tp", + ] + + gen_from_templates(templates) + cython_sources = [ "_datasets_pair.pyx", "_gemm_term_computer.pyx", diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index b8b7d584aee56..dfc2c79bcf41d 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -90,12 +90,18 @@ def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices, rtol=1 def relative_rounding(scalar, n_significant_digits): """Round a scalar to a number of significant digits relatively to its value.""" + if scalar == 0: + return 0.0 magnitude = int(floor(log10(abs(scalar)))) + 1 return round(scalar, n_significant_digits - magnitude) def test_relative_rounding(): + assert relative_rounding(0, 1) == 0.0 + assert relative_rounding(0, 10) == 0.0 + assert relative_rounding(0, 123456) == 0.0 + assert relative_rounding(123456789, 0) == 0 assert relative_rounding(123456789, 2) == 120000000 assert relative_rounding(123456789, 3) == 123000000 @@ -515,6 +521,11 @@ def 
test_pairwise_distances_reduction_is_usable_for(): assert PairwiseDistancesReduction.is_usable_for( X.astype(np.float64), X.astype(np.float64), metric ) + + assert PairwiseDistancesReduction.is_usable_for( + X.astype(np.float32), X.astype(np.float32), metric + ) + assert not PairwiseDistancesReduction.is_usable_for( X.astype(np.int64), Y.astype(np.int64), metric ) @@ -539,7 +550,7 @@ def test_argkmin_factory_method_wrong_usages(): metric = "euclidean" msg = ( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " "got: X.dtype=float32 and Y.dtype=float64" ) with pytest.raises(ValueError, match=msg): @@ -548,7 +559,7 @@ def test_argkmin_factory_method_wrong_usages(): ) msg = ( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " "got: X.dtype=float64 and Y.dtype=int32" ) with pytest.raises(ValueError, match=msg): @@ -597,7 +608,7 @@ def test_radius_neighborhood_factory_method_wrong_usages(): metric = "euclidean" msg = ( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " "got: X.dtype=float32 and Y.dtype=float64" ) with pytest.raises( @@ -609,7 +620,7 @@ def test_radius_neighborhood_factory_method_wrong_usages(): ) msg = ( - "Only 64bit float datasets are supported at this time, " + "Only float64 or float32 datasets pairs are supported at this time, " "got: X.dtype=float64 and Y.dtype=int32" ) with pytest.raises( @@ -659,13 +670,14 @@ def test_radius_neighborhood_factory_method_wrong_usages(): "PairwiseDistancesReduction", [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_chunk_size_agnosticism( global_random_seed, PairwiseDistancesReduction, n_samples, chunk_size, + dtype, n_features=100, - dtype=np.float64, ): # Results must not depend on the chunk 
size rng = np.random.RandomState(global_random_seed) @@ -676,11 +688,13 @@ def test_chunk_size_agnosticism( if PairwiseDistancesReduction is PairwiseDistancesArgKmin: parameter = 10 check_parameters = {} + compute_parameters = {} else: # Scaling the radius slightly with the numbers of dimensions radius = 10 ** np.log(n_features) parameter = radius check_parameters = {"radius": radius} + compute_parameters = {"sort_results": True} ref_dist, ref_indices = PairwiseDistancesReduction.compute( X, @@ -688,6 +702,7 @@ def test_chunk_size_agnosticism( parameter, metric="manhattan", return_distance=True, + **compute_parameters, ) dist, indices = PairwiseDistancesReduction.compute( @@ -697,6 +712,7 @@ def test_chunk_size_agnosticism( chunk_size=chunk_size, metric="manhattan", return_distance=True, + **compute_parameters, ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( @@ -710,13 +726,14 @@ def test_chunk_size_agnosticism( "PairwiseDistancesReduction", [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_n_threads_agnosticism( global_random_seed, PairwiseDistancesReduction, n_samples, chunk_size, + dtype, n_features=100, - dtype=np.float64, ): # Results must not depend on the number of threads rng = np.random.RandomState(global_random_seed) @@ -727,22 +744,29 @@ def test_n_threads_agnosticism( if PairwiseDistancesReduction is PairwiseDistancesArgKmin: parameter = 10 check_parameters = {} + compute_parameters = {} else: # Scaling the radius slightly with the numbers of dimensions radius = 10 ** np.log(n_features) parameter = radius check_parameters = {"radius": radius} + compute_parameters = {"sort_results": True} ref_dist, ref_indices = PairwiseDistancesReduction.compute( X, Y, parameter, return_distance=True, + **compute_parameters, ) with threadpoolctl.threadpool_limits(limits=1, user_api="openmp"): dist, indices = PairwiseDistancesReduction.compute( - X, Y, parameter, 
return_distance=True + X, + Y, + parameter, + return_distance=True, + **compute_parameters, ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( @@ -758,13 +782,14 @@ def test_n_threads_agnosticism( "PairwiseDistancesReduction", [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_strategies_consistency( global_random_seed, PairwiseDistancesReduction, metric, n_samples, + dtype, n_features=10, - dtype=np.float64, ): rng = np.random.RandomState(global_random_seed) @@ -780,11 +805,13 @@ def test_strategies_consistency( if PairwiseDistancesReduction is PairwiseDistancesArgKmin: parameter = 10 check_parameters = {} + compute_parameters = {} else: # Scaling the radius slightly with the numbers of dimensions radius = 10 ** np.log(n_features) parameter = radius check_parameters = {"radius": radius} + compute_parameters = {"sort_results": True} dist_par_X, indices_par_X = PairwiseDistancesReduction.compute( X, @@ -799,6 +826,7 @@ def test_strategies_consistency( chunk_size=n_samples // 4, strategy="parallel_on_X", return_distance=True, + **compute_parameters, ) dist_par_Y, indices_par_Y = PairwiseDistancesReduction.compute( @@ -814,6 +842,7 @@ def test_strategies_consistency( chunk_size=n_samples // 4, strategy="parallel_on_Y", return_distance=True, + **compute_parameters, ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( @@ -829,16 +858,25 @@ def test_strategies_consistency( @pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_pairwise_distances_argkmin( global_random_seed, n_features, translation, metric, strategy, + dtype, n_samples=100, k=10, - dtype=np.float64, ): + # TODO: can we easily fix this discrepancy? 
+ edge_cases = [ + (np.float32, "chebyshev", 1000000.0), + (np.float32, "chebyshev", 1000000.0), + ] + if (dtype, metric, translation) in edge_cases: + pytest.xfail("Numerical differences lead to small differences in results.") + rng = np.random.RandomState(global_random_seed) spread = 1000 X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread @@ -892,14 +930,15 @@ def test_pairwise_distances_argkmin( @pytest.mark.parametrize("translation", [0, 1e6]) @pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS) @pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y")) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_pairwise_distances_radius_neighbors( global_random_seed, n_features, translation, metric, strategy, + dtype, n_samples=100, - dtype=np.float64, ): rng = np.random.RandomState(global_random_seed) spread = 1000 @@ -955,12 +994,13 @@ def test_pairwise_distances_radius_neighbors( [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood], ) @pytest.mark.parametrize("metric", ["manhattan", "euclidean"]) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_memmap_backed_data( metric, PairwiseDistancesReduction, + dtype, n_samples=512, n_features=100, - dtype=np.float64, ): # Results must not depend on the datasets writability rng = np.random.RandomState(0) @@ -974,11 +1014,13 @@ def test_memmap_backed_data( if PairwiseDistancesReduction is PairwiseDistancesArgKmin: parameter = 10 check_parameters = {} + compute_parameters = {} else: # Scaling the radius slightly with the numbers of dimensions radius = 10 ** np.log(n_features) parameter = radius check_parameters = {"radius": radius} + compute_parameters = {"sort_results": True} ref_dist, ref_indices = PairwiseDistancesReduction.compute( X, @@ -986,6 +1028,7 @@ def test_memmap_backed_data( parameter, metric=metric, return_distance=True, + **compute_parameters, ) dist_mm, indices_mm = 
PairwiseDistancesReduction.compute( @@ -994,6 +1037,7 @@ def test_memmap_backed_data( parameter, metric=metric, return_distance=True, + **compute_parameters, ) ASSERT_RESULT[(PairwiseDistancesReduction, dtype)]( @@ -1004,12 +1048,13 @@ def test_memmap_backed_data( @pytest.mark.parametrize("n_samples", [100, 1000]) @pytest.mark.parametrize("n_features", [5, 10, 100]) @pytest.mark.parametrize("num_threads", [1, 2, 8]) +@pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_sqeuclidean_row_norms( global_random_seed, n_samples, n_features, num_threads, - dtype=np.float64, + dtype, ): rng = np.random.RandomState(global_random_seed) spread = 100