diff --git a/sklearn/metrics/_dist_metrics.pyx.tp b/sklearn/metrics/_dist_metrics.pyx.tp index 6b5ea300f038b..eb9b1a46c4d1f 100644 --- a/sklearn/metrics/_dist_metrics.pyx.tp +++ b/sklearn/metrics/_dist_metrics.pyx.tp @@ -241,6 +241,7 @@ METRIC_MAPPING{{name_suffix}} = { 'jaccard': JaccardDistance{{name_suffix}}, 'dice': DiceDistance{{name_suffix}}, 'kulsinski': KulsinskiDistance{{name_suffix}}, + 'precomputed': PrecomputedDistanceMatrix{{name_suffix}}, 'rogerstanimoto': RogersTanimotoDistance{{name_suffix}}, 'russellrao': RussellRaoDistance{{name_suffix}}, 'sokalmichener': SokalMichenerDistance{{name_suffix}}, @@ -348,6 +349,18 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric): "sokalsneath" SokalSneathDistance NNEQ / (NNEQ + 0.5 * NTT) ================= ======================= =============================== + **Metrics with precomputed distances:** Users can compute a distance + matrix themselves and use it to provide the distances, neighbors and other + data defined in this interface. The precomputed distance matrix should have + shape (n_samples_X, n_samples_Y) and satisfy the properties of a valid + distance metric.
+ + ================= ========================= =============================== + identifier class name distance function + ----------------- ------------------------- ------------------------------- + "precomputed" PrecomputedDistanceMatrix predefined + ================= ========================= =============================== + **User-defined distance:** =========== =============== ======= diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp index 1e57b3291a8f4..fe91bf088a8c5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd.tp @@ -29,6 +29,11 @@ cdef class DatasetsPair{{name_suffix}}: cdef float64_t surrogate_dist(self, intp_t i, intp_t j) noexcept nogil +cdef class PrecomputedDistanceMatrix{{name_suffix}}(DatasetsPair{{name_suffix}}): + cdef: + const {{INPUT_DTYPE_t}}[:, ::1] distance_matrix + + cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): cdef: const {{INPUT_DTYPE_t}}[:, ::1] X diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 2c3ca44047145..d0506428c1a8d 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -98,6 +98,12 @@ cdef class DatasetsPair{{name_suffix}}: metric_kwargs = copy.copy(metric_kwargs) metric_kwargs.pop("X_norm_squared", None) metric_kwargs.pop("Y_norm_squared", None) + + if metric == 'precomputed': + return PrecomputedDistanceMatrix{{name_suffix}}( + distance_matrix=Y, + ) + cdef: {{DistanceMetric}} distance_metric = DistanceMetric.get_metric( metric, @@ -158,6 +164,47 @@ cdef class DatasetsPair{{name_suffix}}: # TODO: add "with gil: raise" here when supporting Cython 3.0 return -1 + +@final +cdef class 
PrecomputedDistanceMatrix{{name_suffix}}(DatasetsPair{{name_suffix}}): + """A precomputed distance matrix between row vectors of two arrays. + + Parameters + ---------- + distance_matrix : ndarray of shape (n_samples_X, n_samples_Y) + Entry (i, j) is the distance between X[i] and Y[j]. Must be C-contiguous. + """ + + def __init__( + self, + const {{INPUT_DTYPE_t}}[:, ::1] distance_matrix, + ): + super().__init__( + # This DistanceMetric is necessary for conversion between + # reduced distance and distance (it performs no-ops). + distance_metric={{DistanceMetric}}(), + n_features=0, + ) + # Arrays have already been checked + self.distance_matrix = distance_matrix + + @final + cdef intp_t n_samples_X(self) noexcept nogil: + return self.distance_matrix.shape[0] + + @final + cdef intp_t n_samples_Y(self) noexcept nogil: + return self.distance_matrix.shape[1] + + @final + cdef float64_t surrogate_dist(self, intp_t i, intp_t j) noexcept nogil: + return self.distance_matrix[i, j] + + @final + cdef float64_t dist(self, intp_t i, intp_t j) noexcept nogil: + return self.distance_matrix[i, j] + + @final cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}): """Compute distances between row vectors of two arrays. diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 956de3577bcee..501ba58416d17 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -78,7 +78,7 @@ def valid_metrics(cls) -> List[str]: "hamming", *BOOL_METRICS, } - return sorted(({"sqeuclidean"} | set(METRIC_MAPPING64.keys())) - excluded) + return sorted(({"sqeuclidean", "precomputed"} | set(METRIC_MAPPING64.keys())) - excluded) @classmethod def is_usable_for(cls, X, Y, metric) -> bool: