diff --git a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp index 8899f49330440..2c3ca44047145 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx.tp @@ -1,3 +1,5 @@ +import copy + {{py: implementation_specific_values = [ @@ -84,12 +86,17 @@ cdef class DatasetsPair{{name_suffix}}: datasets_pair: DatasetsPair{{name_suffix}} The suited DatasetsPair{{name_suffix}} implementation. """ - # Y_norm_squared might be propagated down to DatasetsPairs - # via metrics_kwargs when the Euclidean specialisations - # can't be used. To prevent Y_norm_squared to be passed + # X_norm_squared and Y_norm_squared might be propagated + # down to DatasetsPairs via metrics_kwargs when the Euclidean + # specialisations can't be used. + # To prevent X_norm_squared and Y_norm_squared to be passed # down to DistanceMetrics (whose constructors would raise - # a RuntimeError), we pop it here. + # a RuntimeError), we pop them here. if metric_kwargs is not None: + # Copying metric_kwargs not to pop "X_norm_squared" + # and "Y_norm_squared" where they are used + metric_kwargs = copy.copy(metric_kwargs) + metric_kwargs.pop("X_norm_squared", None) metric_kwargs.pop("Y_norm_squared", None) cdef: {{DistanceMetric}} distance_metric = DistanceMetric.get_metric(