From 0243c1d2bb94bc5a39fbaece0483c02567b5b629 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 25 Nov 2022 15:49:58 +0100 Subject: [PATCH 01/11] ENH Add the fused CSR dense case for Euclidean Specializations --- .../_argkmin.pyx.tp | 5 +- .../_dispatcher.py | 21 +-- .../_middle_term_computer.pxd.tp | 37 +++++ .../_middle_term_computer.pyx.tp | 148 +++++++++++++++++- .../_radius_neighbors.pyx.tp | 5 +- .../test_pairwise_distances_reduction.py | 78 +++++---- 6 files changed, 235 insertions(+), 59 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index b8afe5c3cd5f8..58bac2aaf35ca 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -61,10 +61,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): No instance should directly be created outside of this class method. """ - if ( - metric in ("euclidean", "sqeuclidean") - and not (issparse(X) ^ issparse(Y)) # "^" is the XOR operator - ): + if metric in ("euclidean", "sqeuclidean"): # Specialized implementation of ArgKmin for the Euclidean distance # for the dense-dense and sparse-sparse cases. # This implementation computes the distances by chunk using diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 62403d1c334f0..576cc64ff5295 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -119,26 +119,7 @@ def is_valid_sparse_matrix(X): and metric in cls.valid_metrics() ) - # The other joblib-based back-end might be more efficient on fused sparse-dense - # datasets' pairs on metric="(sq)euclidean" for some configurations because it - # uses the Squared Euclidean matrix decomposition, i.e.: - # - # ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - # calling efficient sparse-dense routines for matrix and vectors multiplication - # implemented in SciPy we do not use yet here. - # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa - # TODO: implement specialisation for (sq)euclidean on fused sparse-dense - # using sparse-dense routines for matrix-vector multiplications. - # Currently, only dense-dense and sparse-sparse are optimized for - # the Euclidean case. - fused_sparse_dense_euclidean_case_guard = not ( - (is_valid_sparse_matrix(X) ^ is_valid_sparse_matrix(Y)) # "^" is XOR - and isinstance(metric, str) - and "euclidean" in metric - ) - - return is_usable and fused_sparse_dense_euclidean_case_guard + return is_usable @classmethod @abstractmethod diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp index e6ef5de2727b5..93a46770ce817 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp @@ -186,4 +186,41 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam ) nogil +cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): + cdef: + const DTYPE_t[:] X_data + const SPARSE_INDEX_TYPE_t[:] X_indices + const SPARSE_INDEX_TYPE_t[:] X_indptr + + const DTYPE_t[:, ::1] Y + + bint c_ordered_middle_term + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num + ) nogil + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num + ) nogil + + cdef DTYPE_t * _compute_dist_middle_terms( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil + {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index 4eb3733c42bcf..476a2e65fe85a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -78,6 +78,37 @@ cdef void _middle_term_sparse_sparse_64( if X_i_col_idx == Y_j_col_idx: D[k] += -2 * X_data[X_i_ptr] * Y_data[Y_j_ptr] +# TODO: compare this routine with the similar ones in SciPy, especially +# `csr_matvects` which might implement a better algorithm. +# See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L1139-L1175 # noqa +cdef void _middle_term_sparse_dense_64( + const DTYPE_t[:] X_data, + const SPARSE_INDEX_TYPE_t[:] X_indices, + const SPARSE_INDEX_TYPE_t[:] X_indptr, + ITYPE_t X_start, + ITYPE_t X_end, + const DTYPE_t[:, ::1] Y, + ITYPE_t Y_start, + ITYPE_t Y_end, + bint c_ordered_middle_term, + DTYPE_t * D, +) nogil: + # This routine assumes that D points to the first element of a + # zeroed buffer of length at least equal to n_X × n_Y, conceptually + # representing a 2-d C-ordered array. + cdef: + ITYPE_t i, j, k + ITYPE_t n_X = X_end - X_start + ITYPE_t n_Y = Y_end - Y_start + ITYPE_t X_i_col_idx, X_i_ptr, Y_j_col_idx, Y_j_ptr + + for i in range(n_X): + for j in range(n_Y): + k = i * n_Y + j if c_ordered_middle_term else j * n_X + i + for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]): + X_i_col_idx = X_indices[X_i_ptr] + D[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx] + {{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} @@ -111,7 +142,7 @@ cdef class MiddleTermComputer{{name_suffix}}: n_features, chunk_size, ) -> MiddleTermComputer{{name_suffix}}: - """Return the DatasetsPair implementation for the given arguments. + """Return the MiddleTermComputer implementation for the given arguments. Parameters ---------- @@ -151,12 +182,34 @@ cdef class MiddleTermComputer{{name_suffix}}: n_features, chunk_size, ) - + if X_is_sparse and not Y_is_sparse: + return SparseDenseMiddleTermComputer{{name_suffix}}( + X, + # TODO: remove cast + Y.astype(np.float64, copy=False), + effective_n_threads, + chunks_n_threads, + dist_middle_terms_chunks_size, + n_features, + chunk_size, + c_ordered_middle_term=True + ) + if not X_is_sparse and Y_is_sparse: + return SparseDenseMiddleTermComputer{{name_suffix}}( + Y, + # TODO: remove cast + X.astype(np.float64, copy=False), + effective_n_threads, + chunks_n_threads, + dist_middle_terms_chunks_size, + n_features, + chunk_size, + c_ordered_middle_term=False, + ) raise NotImplementedError( - "X and Y must be both CSR sparse matrices or both numpy arrays." + "X and Y must be CSR sparse matrices or numpy arrays." ) - @classmethod def unpack_csr_matrix(cls, X: csr_matrix): """Ensure that the CSR matrix is indexed with SPARSE_INDEX_TYPE.""" @@ -494,5 +547,92 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam return dist_middle_terms +cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): + """Middle term of the Euclidean distance between chunks of a CSR matrix and a np.ndarray. + + The logic of the computation is wrapped in the routine _middle_term_sparse_dense_64. + This routine iterates over the data, indices and indptr arrays of the sparse matrices + without densifying them. + """ + + def __init__( + self, + X, + Y, + ITYPE_t effective_n_threads, + ITYPE_t chunks_n_threads, + ITYPE_t dist_middle_terms_chunks_size, + ITYPE_t n_features, + ITYPE_t chunk_size, + bint c_ordered_middle_term, + ): + super().__init__( + effective_n_threads, + chunks_n_threads, + dist_middle_terms_chunks_size, + n_features, + chunk_size, + ) + self.X_data, self.X_indices, self.X_indptr = self.unpack_csr_matrix(X) + self.Y = Y + self.c_ordered_middle_term = c_ordered_middle_term + + cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + # Flush the thread dist_middle_terms_chunks to 0.0 + fill( + self.dist_middle_terms_chunks[thread_num].begin(), + self.dist_middle_terms_chunks[thread_num].end(), + 0.0, + ) + + cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + # Flush the thread dist_middle_terms_chunks to 0.0 + fill( + self.dist_middle_terms_chunks[thread_num].begin(), + self.dist_middle_terms_chunks[thread_num].end(), + 0.0, + ) + + cdef DTYPE_t * _compute_dist_middle_terms( + self, + ITYPE_t X_start, + ITYPE_t X_end, + ITYPE_t Y_start, + ITYPE_t Y_end, + ITYPE_t thread_num, + ) nogil: + cdef: + DTYPE_t *dist_middle_terms = ( + self.dist_middle_terms_chunks[thread_num].data() + ) + + _middle_term_sparse_dense_64( + self.X_data, + self.X_indices, + self.X_indptr, + X_start if self.c_ordered_middle_term else Y_start, + X_end if self.c_ordered_middle_term else Y_end, + self.Y, + Y_start if self.c_ordered_middle_term else X_start, + Y_end if self.c_ordered_middle_term else X_end, + self.c_ordered_middle_term, + dist_middle_terms, + ) + + return dist_middle_terms {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index b3f20cac3ea08..e9070b88d99d2 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -82,10 +82,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) No instance should directly be created outside of this class method. """ - if ( - metric in ("euclidean", "sqeuclidean") - and not (issparse(X) ^ issparse(Y)) # "^" is XOR - ): + if metric in ("euclidean", "sqeuclidean"): # Specialized implementation of RadiusNeighbors for the Euclidean # distance for the dense-dense and sparse-sparse cases. # This implementation computes the distances by chunk using diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 4fe8013cd3602..7a1df17c9d509 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,4 +1,3 @@ -import itertools import re import warnings from collections import defaultdict @@ -553,15 +552,11 @@ def test_pairwise_distances_reduction_is_usable_for(): np.asfortranarray(X), Y, metric ) - # We prefer not to use those implementations for fused sparse-dense when - # metric="(sq)euclidean" because it's not yet the most efficient one on - # all configurations of datasets. - # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa - # TODO: implement specialisation for (sq)euclidean on fused sparse-dense - # using sparse-dense routines for matrix-vector multiplications. - assert not BaseDistancesReductionDispatcher.is_usable_for( - X_csr, Y, metric="euclidean" + assert BaseDistancesReductionDispatcher.is_usable_for(X_csr, Y, metric="euclidean") + assert BaseDistancesReductionDispatcher.is_usable_for( + X, Y_csr, metric="sqeuclidean" ) + assert BaseDistancesReductionDispatcher.is_usable_for( X_csr, Y_csr, metric="sqeuclidean" ) @@ -906,24 +901,53 @@ def test_format_agnosticism( **compute_parameters, ) - for _X, _Y in itertools.product((X, X_csr), (Y, Y_csr)): - if _X is X and _Y is Y: - continue - dist, indices = Dispatcher.compute( - _X, - _Y, - parameter, - chunk_size=50, - return_distance=True, - **compute_parameters, - ) - ASSERT_RESULT[(Dispatcher, dtype)]( - dist_dense, - dist, - indices_dense, - indices, - **check_parameters, - ) + dist, indices = Dispatcher.compute( + X_csr, + Y_csr, + parameter, + chunk_size=50, + return_distance=True, + **compute_parameters, + ) + ASSERT_RESULT[(Dispatcher, dtype)]( + dist_dense, + dist, + indices_dense, + indices, + **check_parameters, + ) + + dist, indices = Dispatcher.compute( + X_csr, + Y, + parameter, + chunk_size=50, + return_distance=True, + **compute_parameters, + ) + ASSERT_RESULT[(Dispatcher, dtype)]( + dist_dense, + dist, + indices_dense, + indices, + **check_parameters, + ) + + dist, indices = Dispatcher.compute( + X, + Y_csr, + parameter, + chunk_size=50, + return_distance=True, + **compute_parameters, + ) + ASSERT_RESULT[(Dispatcher, dtype)]( + dist_dense, + dist, + indices_dense, + indices, + **check_parameters, + ) @pytest.mark.parametrize( From 303bc906764e65b9333dbabb80f41b7e851aa7df Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 27 Feb 2023 11:52:12 +0100 Subject: [PATCH 02/11] Remove the upcast from float32 to float64 --- .../_middle_term_computer.pxd.tp | 2 +- .../_middle_term_computer.pyx.tp | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp index 0574a63404af5..35e70a12c9732 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp @@ -192,7 +192,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name const SPARSE_INDEX_TYPE_t[:] X_indices const SPARSE_INDEX_TYPE_t[:] X_indptr - const DTYPE_t[:, ::1] Y + const {{INPUT_DTYPE_t}}[:, ::1] Y bint c_ordered_middle_term diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index a6dcf1bc55852..393d4a02c9594 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -70,16 +70,19 @@ cdef void _middle_term_sparse_sparse_64( if x_col == y_col: D[k] += -2 * X_data[x_ptr] * Y_data[y_ptr] + +{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} + # TODO: compare this routine with the similar ones in SciPy, especially # `csr_matvects` which might implement a better algorithm. # See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L1139-L1175 # noqa -cdef void _middle_term_sparse_dense_64( +cdef void _middle_term_sparse_dense_{{name_suffix}}( const DTYPE_t[:] X_data, const SPARSE_INDEX_TYPE_t[:] X_indices, const SPARSE_INDEX_TYPE_t[:] X_indptr, ITYPE_t X_start, ITYPE_t X_end, - const DTYPE_t[:, ::1] Y, + const {{INPUT_DTYPE_t}}[:, ::1] Y, ITYPE_t Y_start, ITYPE_t Y_end, bint c_ordered_middle_term, @@ -102,9 +105,6 @@ cdef void _middle_term_sparse_dense_64( D[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx] -{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} - - cdef class MiddleTermComputer{{name_suffix}}: """Helper class to compute a Euclidean distance matrix in chunks. @@ -177,8 +177,7 @@ cdef class MiddleTermComputer{{name_suffix}}: if X_is_sparse and not Y_is_sparse: return SparseDenseMiddleTermComputer{{name_suffix}}( X, - # TODO: remove cast - Y.astype(np.float64, copy=False), + Y, effective_n_threads, chunks_n_threads, dist_middle_terms_chunks_size, @@ -189,8 +188,7 @@ cdef class MiddleTermComputer{{name_suffix}}: if not X_is_sparse and Y_is_sparse: return SparseDenseMiddleTermComputer{{name_suffix}}( Y, - # TODO: remove cast - X.astype(np.float64, copy=False), + X, effective_n_threads, chunks_n_threads, dist_middle_terms_chunks_size, @@ -542,7 +540,7 @@ cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{nam cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}): """Middle term of the Euclidean distance between chunks of a CSR matrix and a np.ndarray. - The logic of the computation is wrapped in the routine _middle_term_sparse_dense_64. + The logic of the computation is wrapped in the routine _middle_term_sparse_dense_{{name_suffix}}. This routine iterates over the data, indices and indptr arrays of the sparse matrices without densifying them. """ @@ -612,7 +610,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name self.dist_middle_terms_chunks[thread_num].data() ) - _middle_term_sparse_dense_64( + _middle_term_sparse_dense_{{name_suffix}}( self.X_data, self.X_indices, self.X_indptr, From 8576390962c677b21ba6c6801b2c849bb71e476c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 27 Feb 2023 12:09:20 +0100 Subject: [PATCH 03/11] DOC Add a changelog entry for 1.3 --- doc/whats_new/v1.3.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index ce579874f886c..9c54d0e9f8d01 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -76,6 +76,32 @@ Changes impacting all modules by :user:`John Pangas `, :user:`Rahil Parikh ` , and :user:`Alex Buzenet `. +- |Feature| The following functions and estimators now support pairs of + sparse CSR matrices and dense datasets: + + - :func:`sklearn.metrics.pairwise_distances_argmin` + - :func:`sklearn.metrics.pairwise_distances_argmin_min` + - :class:`sklearn.cluster.AffinityPropagation` + - :class:`sklearn.cluster.Birch` + - :class:`sklearn.cluster.MeanShift` + - :class:`sklearn.cluster.OPTICS` + - :class:`sklearn.cluster.SpectralClustering` + - :func:`sklearn.feature_selection.mutual_info_regression` + - :class:`sklearn.neighbors.KNeighborsClassifier` + - :class:`sklearn.neighbors.KNeighborsRegressor` + - :class:`sklearn.neighbors.RadiusNeighborsClassifier` + - :class:`sklearn.neighbors.RadiusNeighborsRegressor` + - :class:`sklearn.neighbors.LocalOutlierFactor` + - :class:`sklearn.neighbors.NearestNeighbors` + - :class:`sklearn.manifold.Isomap` + - :class:`sklearn.manifold.LocallyLinearEmbedding` + - :class:`sklearn.manifold.TSNE` + - :func:`sklearn.manifold.trustworthiness` + - :class:`sklearn.semi_supervised.LabelPropagation` + - :class:`sklearn.semi_supervised.LabelSpreading` + + :pr:`25044` by :user:`Julien Jerphanion `. + Changelog --------- From a874def9ced531fe120e65ab70f22bb68f9f3b0b Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 27 Feb 2023 13:44:23 +0100 Subject: [PATCH 04/11] Remove outdated TODO comment This was already studied in: https://github.com/scikit-learn/scikit-learn/pull/25449 Co-authored-by: Vincent M --- .../_pairwise_distances_reduction/_middle_term_computer.pyx.tp | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index 393d4a02c9594..0f30f1e23a432 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -73,9 +73,6 @@ cdef void _middle_term_sparse_sparse_64( {{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}} -# TODO: compare this routine with the similar ones in SciPy, especially -# `csr_matvects` which might implement a better algorithm. -# See: https://github.com/scipy/scipy/blob/e58292e066ba2cb2f3d1e0563ca9314ff1f4f311/scipy/sparse/sparsetools/csr.h#L1139-L1175 # noqa cdef void _middle_term_sparse_dense_{{name_suffix}}( const DTYPE_t[:] X_data, const SPARSE_INDEX_TYPE_t[:] X_indices, From 12700d696cee21f02afcec2f0764fac7252a0f16 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Mon, 27 Feb 2023 14:29:40 +0100 Subject: [PATCH 05/11] Add noexcept qualification This is required with Cython>=3.0. --- .../_middle_term_computer.pxd.tp | 6 +++--- .../_middle_term_computer.pyx.tp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp index 35e70a12c9732..6e9a3a0bb7392 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp @@ -203,7 +203,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num - ) nogil + ) noexcept nogil cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks( self, @@ -212,7 +212,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num - ) nogil + ) noexcept nogil cdef DTYPE_t * _compute_dist_middle_terms( self, @@ -221,6 +221,6 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, - ) nogil + ) noexcept nogil {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index 0f30f1e23a432..b94667e7f8207 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -571,7 +571,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, - ) nogil: + ) noexcept nogil: # Flush the thread dist_middle_terms_chunks to 0.0 fill( self.dist_middle_terms_chunks[thread_num].begin(), @@ -586,7 +586,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, - ) nogil: + ) noexcept nogil: # Flush the thread dist_middle_terms_chunks to 0.0 fill( self.dist_middle_terms_chunks[thread_num].begin(), @@ -601,7 +601,7 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_start, ITYPE_t Y_end, ITYPE_t thread_num, - ) nogil: + ) noexcept nogil: cdef: DTYPE_t *dist_middle_terms = ( self.dist_middle_terms_chunks[thread_num].data() From 6563505f56553beb705591eb983b1580c2d13a4f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 28 Feb 2023 14:26:06 +0100 Subject: [PATCH 06/11] TST Completes tests Co-authored-by: Omar Salman --- .../test_pairwise_distances_reduction.py | 68 ++++++------------- 1 file changed, 20 insertions(+), 48 deletions(-) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 7a1df17c9d509..ad0ddbc60e9bd 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,3 +1,4 @@ +import itertools import re import warnings from collections import defaultdict @@ -901,53 +902,24 @@ def test_format_agnosticism( **compute_parameters, ) - dist, indices = Dispatcher.compute( - X_csr, - Y_csr, - parameter, - chunk_size=50, - return_distance=True, - **compute_parameters, - ) - ASSERT_RESULT[(Dispatcher, dtype)]( - dist_dense, - dist, - indices_dense, - indices, - **check_parameters, - ) - - dist, indices = Dispatcher.compute( - X_csr, - Y, - parameter, - chunk_size=50, - return_distance=True, - **compute_parameters, - ) - ASSERT_RESULT[(Dispatcher, dtype)]( - dist_dense, - dist, - indices_dense, - indices, - **check_parameters, - ) - - dist, indices = Dispatcher.compute( - X, - Y_csr, - parameter, - chunk_size=50, - return_distance=True, - **compute_parameters, - ) - ASSERT_RESULT[(Dispatcher, dtype)]( - dist_dense, - dist, - indices_dense, - indices, - **check_parameters, - ) + for _X, _Y in itertools.product((X, X_csr), (Y, Y_csr)): + if _X is X and _Y is Y: + continue + dist, indices = Dispatcher.compute( + _X, + _Y, + parameter, + chunk_size=50, + return_distance=True, + **compute_parameters, + ) + ASSERT_RESULT[(Dispatcher, dtype)]( + dist_dense, + dist, + indices_dense, + indices, + **check_parameters, + ) @pytest.mark.parametrize( @@ -1084,7 +1056,7 @@ def test_pairwise_distances_argkmin( row_idx, argkmin_indices_ref[row_idx] ] - for _X, _Y in [(X, Y), (X_csr, Y_csr)]: + for _X, _Y in itertools.product((X, X_csr), (Y, Y_csr)): argkmin_distances, argkmin_indices = ArgKmin.compute( _X, _Y, From ff6d47f6bea60941879b5668a11c61624c855bea Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 28 Feb 2023 14:26:53 +0100 Subject: [PATCH 07/11] DOC Reword whats_new entry --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 9c54d0e9f8d01..da62daf9aa945 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -77,7 +77,7 @@ Changes impacting all modules and :user:`Alex Buzenet `. - |Feature| The following functions and estimators now support pairs of - sparse CSR matrices and dense datasets: + sparse CSR matrices and dense NumPy arrays for datasets: - :func:`sklearn.metrics.pairwise_distances_argmin` - :func:`sklearn.metrics.pairwise_distances_argmin_min` From 17b6845fddd3766f930f749e352ae1b037cbbf2d Mon Sep 17 00:00:00 2001 From: Julien Date: Wed, 1 Mar 2023 09:57:16 -0500 Subject: [PATCH 08/11] DOC Reword `whats_new` entry --- doc/whats_new/v1.3.rst | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index da62daf9aa945..04c9ed1c99587 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -76,8 +76,11 @@ Changes impacting all modules by :user:`John Pangas `, :user:`Rahil Parikh ` , and :user:`Alex Buzenet `. -- |Feature| The following functions and estimators now support pairs of - sparse CSR matrices and dense NumPy arrays for datasets: +- |Enhancement| A multi-threaded routine for pair of datasets consisting + of a sparse CSR matrix and a dense NumPy array has been introduced. + + The following estimators now either support this new case or have + better performance for this case: - :func:`sklearn.metrics.pairwise_distances_argmin` - :func:`sklearn.metrics.pairwise_distances_argmin_min` @@ -100,6 +103,9 @@ Changes impacting all modules - :class:`sklearn.semi_supervised.LabelPropagation` - :class:`sklearn.semi_supervised.LabelSpreading` + For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 + times faster for this case now on laptops. + :pr:`25044` by :user:`Julien Jerphanion `. Changelog From 22d2cd2d266c0be380ed67f5d7dd011552bfe051 Mon Sep 17 00:00:00 2001 From: Julien Date: Wed, 1 Mar 2023 10:44:24 -0500 Subject: [PATCH 09/11] DOC Better word the whats_new entry Co-authored-by: Olivier Grisel --- doc/whats_new/v1.3.rst | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 04c9ed1c99587..9a2832ed69c38 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -76,12 +76,12 @@ Changes impacting all modules by :user:`John Pangas `, :user:`Rahil Parikh ` , and :user:`Alex Buzenet `. -- |Enhancement| A multi-threaded routine for pair of datasets consisting - of a sparse CSR matrix and a dense NumPy array has been introduced. - - The following estimators now either support this new case or have - better performance for this case: +- |Enhancement| Added a multi-threaded Cython routine to the compute squared + Euclidean distances (sometimes followed by a fused reduction operation) for a + pair of datasets consisting of a sparse CSR matrix and a dense NumPy. + This can improve the performance following functions and estimators. + - :func:`sklearn.metrics.pairwise_distances_argmin` - :func:`sklearn.metrics.pairwise_distances_argmin_min` - :class:`sklearn.cluster.AffinityPropagation` @@ -103,8 +103,12 @@ Changes impacting all modules - :class:`sklearn.semi_supervised.LabelPropagation` - :class:`sklearn.semi_supervised.LabelSpreading` - For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 - times faster for this case now on laptops. + A typical example of this performance improvement happens when passing a sparse + CSR matrix to the `predict` or `transform` method of estimators that rely on + a dense numpy representation to store their fitted parameters (or the reverse). + + For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 times faster + for this case now on commonly available laptops. :pr:`25044` by :user:`Julien Jerphanion `. From 45d425a6a90c2b5e2c8e342df0a00e6b5eb12132 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 9 Mar 2023 10:28:52 +0100 Subject: [PATCH 10/11] DOC Better document implementation Signed-off-by: Julien Jerphanion --- doc/whats_new/v1.3.rst | 6 +-- .../_middle_term_computer.pxd.tp | 4 ++ .../_middle_term_computer.pyx.tp | 38 +++++++++++++------ 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 9a2832ed69c38..8e5137807ea01 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -80,8 +80,8 @@ Changes impacting all modules Euclidean distances (sometimes followed by a fused reduction operation) for a pair of datasets consisting of a sparse CSR matrix and a dense NumPy. - This can improve the performance following functions and estimators. - + This can improve the performance of following functions and estimators: + - :func:`sklearn.metrics.pairwise_distances_argmin` - :func:`sklearn.metrics.pairwise_distances_argmin_min` - :class:`sklearn.cluster.AffinityPropagation` @@ -105,7 +105,7 @@ Changes impacting all modules A typical example of this performance improvement happens when passing a sparse CSR matrix to the `predict` or `transform` method of estimators that rely on - a dense numpy representation to store their fitted parameters (or the reverse). + a dense NumPy representation to store their fitted parameters (or the reverse). For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 times faster for this case now on commonly available laptops. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp index 6e9a3a0bb7392..6b116f0f44d6f 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd.tp @@ -194,6 +194,10 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name const {{INPUT_DTYPE_t}}[:, ::1] Y + # We treat the dense-sparse case with the sparse-dense case by simply + # treating the dist_middle_terms as F-ordered and by swapping arguments. + # This attribute is meant to encode the case and adapt the logic + # accordingly. bint c_ordered_middle_term cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks( diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index b94667e7f8207..1c5776a1f2d3d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -83,11 +83,11 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}( ITYPE_t Y_start, ITYPE_t Y_end, bint c_ordered_middle_term, - DTYPE_t * D, + DTYPE_t * dist_middle_terms, ) nogil: - # This routine assumes that D points to the first element of a - # zeroed buffer of length at least equal to n_X × n_Y, conceptually - # representing a 2-d C-ordered array. + # This routine assumes that dist_middle_terms is a pointer to the first element + # of a zeroed buffer of length at least equal to n_X × n_Y, conceptually + # representing a 2-d C-ordered of F-ordered array. cdef: ITYPE_t i, j, k ITYPE_t n_X = X_end - X_start @@ -99,7 +99,7 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}( k = i * n_Y + j if c_ordered_middle_term else j * n_X + i for X_i_ptr in range(X_indptr[X_start+i], X_indptr[X_start+i+1]): X_i_col_idx = X_indices[X_i_ptr] - D[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx] + dist_middle_terms[k] += -2 * X_data[X_i_ptr] * Y[Y_start + j, X_i_col_idx] cdef class MiddleTermComputer{{name_suffix}}: @@ -183,7 +183,14 @@ cdef class MiddleTermComputer{{name_suffix}}: c_ordered_middle_term=True ) if not X_is_sparse and Y_is_sparse: + # NOTE: The Dense-Sparse case is implement via the Sparse-Dense case. + # + # To do so: + # - X (dense) and Y (sparse) are swapped + # - the distance middle term is seen as F-ordered for consistency + # (c_ordered_middle_term = False) return SparseDenseMiddleTermComputer{{name_suffix}}( + # Mind that X and Y are swapped here. Y, X, effective_n_threads, @@ -572,7 +579,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_end, ITYPE_t thread_num, ) noexcept nogil: - # Flush the thread dist_middle_terms_chunks to 0.0 + # Fill the thread's dist_middle_terms_chunks with 0.0 before + # computing its elements in _compute_dist_middle_terms. fill( self.dist_middle_terms_chunks[thread_num].begin(), self.dist_middle_terms_chunks[thread_num].end(), @@ -587,7 +595,8 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name ITYPE_t Y_end, ITYPE_t thread_num, ) noexcept nogil: - # Flush the thread dist_middle_terms_chunks to 0.0 + # Fill the thread's dist_middle_terms_chunks with 0.0 before + # computing its elements in _compute_dist_middle_terms. fill( self.dist_middle_terms_chunks[thread_num].begin(), self.dist_middle_terms_chunks[thread_num].end(), @@ -607,15 +616,22 @@ cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name self.dist_middle_terms_chunks[thread_num].data() ) + # For the dense-sparse case, we use the sparse-dense case + # with dist_middle_terms seen as F-ordered. + # Hence we swap indices pointers here. + if not self.c_ordered_middle_term: + X_start, Y_start = Y_start, X_start + X_end, Y_end = Y_end, X_end + _middle_term_sparse_dense_{{name_suffix}}( self.X_data, self.X_indices, self.X_indptr, - X_start if self.c_ordered_middle_term else Y_start, - X_end if self.c_ordered_middle_term else Y_end, + X_start, + X_end, self.Y, - Y_start if self.c_ordered_middle_term else X_start, - Y_end if self.c_ordered_middle_term else X_end, + Y_start, + Y_end, self.c_ordered_middle_term, dist_middle_terms, ) From b95ca39d7060d86bc86bf44b0f3a8466d819d630 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Fri, 10 Mar 2023 06:05:55 -0500 Subject: [PATCH 11/11] DOC Better word comment and changelog entry Co-authored-by: Christian Lorentzen --- doc/whats_new/v1.3.rst | 4 ++-- .../_middle_term_computer.pyx.tp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 8e5137807ea01..c3f8de1e55408 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -107,8 +107,8 @@ Changes impacting all modules CSR matrix to the `predict` or `transform` method of estimators that rely on a dense NumPy representation to store their fitted parameters (or the reverse). - For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is up to 2 times faster - for this case now on commonly available laptops. + For instance, :meth:`sklearn.NearestNeighbors.kneighbors` is now up to 2 times faster + for this case on commonly available laptops. :pr:`25044` by :user:`Julien Jerphanion `. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp index 1c5776a1f2d3d..255efc83565d5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx.tp @@ -86,7 +86,7 @@ cdef void _middle_term_sparse_dense_{{name_suffix}}( DTYPE_t * dist_middle_terms, ) nogil: # This routine assumes that dist_middle_terms is a pointer to the first element - # of a zeroed buffer of length at least equal to n_X × n_Y, conceptually + # of a buffer filled with zeros of length at least equal to n_X × n_Y, conceptually # representing a 2-d C-ordered of F-ordered array. cdef: ITYPE_t i, j, k