@@ -111,6 +111,18 @@ def assert_no_missing_neighbors(
111
111
indices_row_b ,
112
112
threshold ,
113
113
):
114
+ """Compare the indices of neighbors in two results sets.
115
+
116
+ Any neighbor index with a distance below the precision threshold should
117
+ match one in the other result set. We ignore the last few neighbors beyond
118
+ the threshold as those can typically be missing due to rounding errors.
119
+
120
+ For radius queries, the threshold is just the radius minus the expected
121
+ precision level.
122
+
123
+ For k-NN queries, it is the maxium distance to the k-th neighbor minus the
124
+ expected precision level.
125
+ """
114
126
mask_a = dist_row_a < threshold
115
127
mask_b = dist_row_b < threshold
116
128
missing_from_b = np .setdiff1d (indices_row_a [mask_a ], indices_row_b )
@@ -179,8 +191,15 @@ def assert_compatible_argkmin_results(
179
191
atol ,
180
192
)
181
193
182
- # Check that any neighbor with distances below the rounding error threshold have
183
- # matching indices.
194
+ # Check that any neighbor with distances below the rounding error
195
+ # threshold have matching indices. The threshold is the distance to the
196
+ # k-th neighbors minus the expected precision level:
197
+ #
198
+ # (1 - rtol) * dist_k - atol
199
+ #
200
+ # Where dist_k is defined as the maxium distance to the kth-neighbor
201
+ # among the two result sets. This way of defining the threshold is
202
+ # stricter than taking the minimum of the two.
184
203
threshold = (1 - rtol ) * np .maximum (
185
204
np .max (dist_row_a ), np .max (dist_row_b )
186
205
) - atol
0 commit comments