|
41 | 41 | neighbors.radius_neighbors_graph = ignore_warnings(
|
42 | 42 | neighbors.radius_neighbors_graph)
|
43 | 43 |
|
| 44 | + |
44 | 45 | def _weight_func(dist):
|
45 | 46 | """ Weight function to replace lambda d: d ** -2.
|
46 | 47 | The lambda function is not valid because:
|
@@ -351,7 +352,11 @@ def test_neighbors_regressors_zero_distance():
|
351 | 352 |
|
352 | 353 |
|
353 | 354 | def test_radius_neighbors_boundary_handling():
|
354 |
| - """Test whether points lying on boundary are handled consistently""" |
| 355 | + """Test whether points lying on boundary are handled consistently |
| 356 | +
|
| 357 | + Also ensures that even with only one query point, an object array |
| 358 | + is returned rather than a 2d array. |
| 359 | + """ |
355 | 360 |
|
356 | 361 | X = np.array([[1.5], [3.0], [3.01]])
|
357 | 362 | radius = 3.0
|
@@ -941,7 +946,8 @@ def test_non_euclidean_kneighbors():
|
941 | 946 | nbrs_graph = neighbors.radius_neighbors_graph(
|
942 | 947 | X, radius, metric=metric).toarray()
|
943 | 948 | nbrs1 = neighbors.NearestNeighbors(metric=metric, radius=radius).fit(X)
|
944 |
| - assert_array_equal(nbrs_graph, nbrs1.radius_neighbors_graph(X).toarray()) |
| 949 | + assert_array_equal(nbrs_graph, |
| 950 | + nbrs1.radius_neighbors_graph(X).toarray()) |
945 | 951 |
|
946 | 952 | # Raise error when wrong parameters are supplied,
|
947 | 953 | X_nbrs = neighbors.NearestNeighbors(3, metric='manhattan')
|
@@ -1042,7 +1048,8 @@ def test_k_and_radius_neighbors_duplicates():
|
1042 | 1048 | rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5)
|
1043 | 1049 | assert_array_equal(rng.A, np.ones((2, 2)))
|
1044 | 1050 |
|
1045 |
| - rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5, mode='distance') |
| 1051 | + rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5, |
| 1052 | + mode='distance') |
1046 | 1053 | assert_array_equal(rng.A, [[0, 1], [1, 0]])
|
1047 | 1054 | assert_array_equal(rng.indices, [0, 1, 0, 1])
|
1048 | 1055 | assert_array_equal(rng.data, [0, 1, 1, 0])
|
|
0 commit comments