@@ -166,14 +166,20 @@ def test_radius_neighbors():
166
166
lshf .fit (X )
167
167
168
168
for i in range (n_iter ):
169
+ # Select a random point in the dataset as the query
169
170
query = X [rng .randint (0 , n_samples )]
171
+
172
+ # At least one neighbor should be returned when the radius is the
173
+ # mean distance from the query to the points of the dataset.
170
174
mean_dist = np .mean (pairwise_distances (query , X , metric = 'cosine' ))
171
175
neighbors = lshf .radius_neighbors (query , radius = mean_dist ,
172
176
return_distance = False )
177
+
173
178
assert_equal (neighbors .shape , (1 ,))
174
179
assert_equal (neighbors .dtype , object )
175
180
assert_greater (neighbors [0 ].shape [0 ], 0 )
176
- # All distances should be less than mean_dist
181
+ # All distances to points in the results of the radius query should
182
+ # be less than mean_dist
177
183
distances , neighbors = lshf .radius_neighbors (query ,
178
184
radius = mean_dist ,
179
185
return_distance = True )
@@ -184,23 +190,33 @@ def test_radius_neighbors():
184
190
queries = X [rng .randint (0 , n_samples , n_queries )]
185
191
distances , neighbors = lshf .radius_neighbors (queries ,
186
192
return_distance = True )
187
- assert_equal (neighbors .shape [0 ], n_queries )
188
- assert_equal (distances .shape [0 ], n_queries )
189
- # dists and inds should not be 2D arrays
190
- assert_equal (distances .ndim , 1 )
191
- assert_equal (neighbors .ndim , 1 )
193
+
194
+ # dists and inds should be 1D arrays of arrays of variable lengths,
195
+ # hence the use of the object dtype.
196
+ assert_equal (distances .shape , (n_queries ,))
197
+ assert_equal (distances .dtype , object )
198
+ assert_equal (neighbors .shape , (n_queries ,))
199
+ assert_equal (neighbors .dtype , object )
192
200
193
201
# Compare with exact neighbor search
194
202
query = X [rng .randint (0 , n_samples )]
195
203
mean_dist = np .mean (pairwise_distances (query , X , metric = 'cosine' ))
196
- nbrs = NearestNeighbors (algorithm = 'brute' , metric = 'cosine' )
197
- nbrs .fit (X )
204
+ nbrs = NearestNeighbors (algorithm = 'brute' , metric = 'cosine' ).fit (X )
198
205
199
- distances_approx , _ = lshf .radius_neighbors (query , radius = mean_dist )
200
206
distances_exact , _ = nbrs .radius_neighbors (query , radius = mean_dist )
201
- # Distances of exact neighbors is less than or equal to approximate
202
- assert_true (np .all (np .less_equal (np .sort (distances_exact [0 ]),
203
- np .sort (distances_approx [0 ]))))
207
+ distances_approx , _ = lshf .radius_neighbors (query , radius = mean_dist )
208
+
209
+ # Radius-based queries do not sort the result points and the order
210
+ # depends on the method, the random_state and the dataset order. Therefore
211
+ # we need to sort the results ourselves before performing any comparison.
212
+ sorted_dists_exact = np .sort (distances_exact [0 ])
213
+ sorted_dists_approx = np .sort (distances_approx [0 ])
214
+
215
+ # Distances to exact neighbors are less than or equal to approximate
216
+ # counterparts as the approximate radius query might have missed some
217
+ # closer neighbors.
218
+ assert_true (np .all (np .less_equal (sorted_dists_exact ,
219
+ sorted_dists_approx )))
204
220
205
221
206
222
def test_distances ():
@@ -220,15 +236,13 @@ def test_distances():
220
236
distances , neighbors = lshf .kneighbors (query ,
221
237
n_neighbors = n_neighbors ,
222
238
return_distance = True )
223
- # Returned neighbors should be from closest to farthest.
224
- assert_true (np .all (np .diff (distances [0 ]) >= 0 ))
225
239
226
- mean_dist = np .mean (pairwise_distances (query , X , metric = 'cosine' ))
227
- distances , neighbors = lshf .radius_neighbors (query ,
228
- radius = mean_dist ,
229
- return_distance = True )
240
+ # Returned neighbors should be from closest to farthest, that is
241
+ # increasing distance values.
230
242
assert_true (np .all (np .diff (distances [0 ]) >= 0 ))
231
243
244
+ # The radius_neighbors method does not guarantee the order of the results.
245
+
232
246
233
247
def test_fit ():
234
248
"""Checks whether `fit` method sets all attribute values correctly."""
@@ -407,8 +421,3 @@ def test_sparse_input():
407
421
assert_array_equal (a , b )
408
422
for a , b in zip (i_sparse , i_dense ):
409
423
assert_array_equal (a , b )
410
-
411
-
412
- if __name__ == "__main__" :
413
- import nose
414
- nose .runmodule ()
0 commit comments