diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index fbd2c5852b62c..f67df8333fcec 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -162,11 +162,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): for i in range(n_samples_X): for j in range(n_samples_Y): heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), - Y_start + j, + values=heaps_r_distances + i * self.k, + indices=heaps_indices + i * self.k, + size=self.k, + val=self.datasets_pair.surrogate_dist(X_start + i, Y_start + j), + val_idx=Y_start + j, ) cdef void _parallel_on_X_init_chunk( @@ -255,11 +255,11 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): for thread_num in range(self.chunks_n_threads): for jdx in range(self.k): heap_push( - &self.argkmin_distances[X_start + idx, 0], - &self.argkmin_indices[X_start + idx, 0], - self.k, - self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], - self.heaps_indices_chunks[thread_num][idx * self.k + jdx], + values=&self.argkmin_distances[X_start + idx, 0], + indices=&self.argkmin_indices[X_start + idx, 0], + size=self.k, + val=self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx], + val_idx=self.heaps_indices_chunks[thread_num][idx * self.k + jdx], ) cdef void _parallel_on_Y_finalize( @@ -292,7 +292,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): num_threads=self.effective_n_threads): for j in range(self.k): distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. + # Guard against potential -0., causing nan production. max(distances[i, j], 0.) ) @@ -304,7 +304,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}): # Values are returned identically to the way `KNeighborsMixin.kneighbors` # returns values. This is counter-intuitive but this allows not using - # complex adaptations where `ArgKmin{{name_suffix}}.compute` is called. + # complex adaptations where `ArgKmin.compute` is called. return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices) return np.asarray(self.argkmin_indices) @@ -330,8 +330,10 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): ): if ( metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs + len(metric_kwargs) > 0 and ( + "Y_norm_squared" not in metric_kwargs or + "X_norm_squared" not in metric_kwargs + ) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " @@ -365,20 +367,31 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): metric_kwargs.pop("Y_norm_squared"), ensure_2d=False, input_name="Y_norm_squared", - dtype=np.float64 + dtype=np.float64, ) else: self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( - Y, self.effective_n_threads + Y, + self.effective_n_threads, ) - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms{{name_suffix}}( - X, self.effective_n_threads + if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: + self.X_norm_squared = check_array( + metric_kwargs.pop("X_norm_squared"), + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64, ) - ) + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms{{name_suffix}}( + X, + self.effective_n_threads, + ) + ) + self.use_squared_distances = use_squared_distances @final @@ -470,6 +483,7 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): ) nogil: cdef: ITYPE_t i, j + DTYPE_t sqeuclidean_dist_i_j ITYPE_t n_X = X_end - X_start ITYPE_t n_Y = Y_end - Y_start DTYPE_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms( @@ -483,20 +497,22 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): # which keep tracks of the argkmin. for i in range(n_X): for j in range(n_Y): + sqeuclidean_dist_i_j = ( + self.X_norm_squared[i + X_start] + + dist_middle_terms[i * n_Y + j] + + self.Y_norm_squared[j + Y_start] + ) + + # Catastrophic cancellation might cause -0. to be present, + # e.g. when computing d(x_i, y_i) when X is Y. + sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j) + heap_push( - heaps_r_distances + i * self.k, - heaps_indices + i * self.k, - self.k, - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - ( - self.X_norm_squared[i + X_start] + - dist_middle_terms[i * n_Y + j] + - self.Y_norm_squared[j + Y_start] - ), - j + Y_start, + values=heaps_r_distances + i * self.k, + indices=heaps_indices + i * self.k, + size=self.k, + val=sqeuclidean_dist_i_j, + val_idx=j + Y_start, ) {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index 26d25e1040003..aec943448be3f 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -311,7 +311,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}) for j in range(deref(self.neigh_indices)[i].size()): deref(self.neigh_distances)[i][j] = ( self.datasets_pair.distance_metric._rdist_to_dist( - # Guard against eventual -0., causing nan production. + # Guard against potential -0., causing nan production. max(deref(self.neigh_distances)[i][j], 0.) ) ) @@ -338,8 +338,10 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} ): if ( metric_kwargs is not None and - len(metric_kwargs) > 0 and - "Y_norm_squared" not in metric_kwargs + len(metric_kwargs) > 0 and ( + "Y_norm_squared" not in metric_kwargs or + "X_norm_squared" not in metric_kwargs + ) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " @@ -374,16 +376,31 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} metric_kwargs.pop("Y_norm_squared"), ensure_2d=False, input_name="Y_norm_squared", - dtype=np.float64 + dtype=np.float64, ) else: - self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(Y, self.effective_n_threads) + self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}( + Y, + self.effective_n_threads, + ) + + if metric_kwargs is not None and "X_norm_squared" in metric_kwargs: + self.X_norm_squared = check_array( + metric_kwargs.pop("X_norm_squared"), + ensure_2d=False, + input_name="X_norm_squared", + dtype=np.float64, + ) + else: + # Do not recompute norms if datasets are identical. + self.X_norm_squared = ( + self.Y_norm_squared if X is Y else + _sqeuclidean_row_norms{{name_suffix}}( + X, + self.effective_n_threads, + ) + ) - # Do not recompute norms if datasets are identical. - self.X_norm_squared = ( - self.Y_norm_squared if X is Y else - _sqeuclidean_row_norms{{name_suffix}}(X, self.effective_n_threads) - ) self.use_squared_distances = use_squared_distances if use_squared_distances: @@ -480,7 +497,7 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} ) nogil: cdef: ITYPE_t i, j - DTYPE_t squared_dist_i_j + DTYPE_t sqeuclidean_dist_i_j ITYPE_t n_X = X_end - X_start ITYPE_t n_Y = Y_end - Y_start DTYPE_t *dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms( @@ -490,17 +507,18 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} # Pushing the distance and their associated indices in vectors. for i in range(n_X): for j in range(n_Y): - # Using the squared euclidean distance as the rank-preserving distance: - # - # ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||² - # - squared_dist_i_j = ( + sqeuclidean_dist_i_j = ( self.X_norm_squared[i + X_start] + dist_middle_terms[i * n_Y + j] + self.Y_norm_squared[j + Y_start] ) - if squared_dist_i_j <= self.r_radius: - deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(squared_dist_i_j) + + # Catastrophic cancellation might cause -0. to be present, + # e.g. when computing d(x_i, y_i) when X is Y. + sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j) + + if sqeuclidean_dist_i_j <= self.r_radius: + deref(self.neigh_distances_chunks[thread_num])[i + X_start].push_back(sqeuclidean_dist_i_j) deref(self.neigh_indices_chunks[thread_num])[i + X_start].push_back(j + Y_start) {{endfor}}