Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dd5c2e7

Browse files
committed
TST ensure return type of radius_neighbors is object
and Pep8
1 parent e01a16d commit dd5c2e7

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

sklearn/neighbors/tests/test_approximate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,9 @@ def test_radius_neighbors():
170170
mean_dist = np.mean(pairwise_distances(query, X, metric='cosine'))
171171
neighbors = lshf.radius_neighbors(query, radius=mean_dist,
172172
return_distance=False)
173-
# At least one neighbor should be returned.
174-
assert_greater(neighbors.shape[0], 0)
173+
assert_equal(neighbors.shape, (1,))
174+
assert_equal(neighbors.dtype, object)
175+
assert_greater(neighbors[0].shape[0], 0)
175176
# All distances should be less than mean_dist
176177
distances, neighbors = lshf.radius_neighbors(query,
177178
radius=mean_dist,

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
neighbors.radius_neighbors_graph = ignore_warnings(
4242
neighbors.radius_neighbors_graph)
4343

44+
4445
def _weight_func(dist):
4546
""" Weight function to replace lambda d: d ** -2.
4647
The lambda function is not valid because:
@@ -351,7 +352,11 @@ def test_neighbors_regressors_zero_distance():
351352

352353

353354
def test_radius_neighbors_boundary_handling():
354-
"""Test whether points lying on boundary are handled consistently"""
355+
"""Test whether points lying on boundary are handled consistently
356+
357+
Also ensures that even with only one query point, an object array
358+
is returned rather than a 2d array.
359+
"""
355360

356361
X = np.array([[1.5], [3.0], [3.01]])
357362
radius = 3.0
@@ -941,7 +946,8 @@ def test_non_euclidean_kneighbors():
941946
nbrs_graph = neighbors.radius_neighbors_graph(
942947
X, radius, metric=metric).toarray()
943948
nbrs1 = neighbors.NearestNeighbors(metric=metric, radius=radius).fit(X)
944-
assert_array_equal(nbrs_graph, nbrs1.radius_neighbors_graph(X).toarray())
949+
assert_array_equal(nbrs_graph,
950+
nbrs1.radius_neighbors_graph(X).toarray())
945951

946952
# Raise error when wrong parameters are supplied,
947953
X_nbrs = neighbors.NearestNeighbors(3, metric='manhattan')
@@ -1042,7 +1048,8 @@ def test_k_and_radius_neighbors_duplicates():
10421048
rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5)
10431049
assert_array_equal(rng.A, np.ones((2, 2)))
10441050

1045-
rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5, mode='distance')
1051+
rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5,
1052+
mode='distance')
10461053
assert_array_equal(rng.A, [[0, 1], [1, 0]])
10471054
assert_array_equal(rng.indices, [0, 1, 0, 1])
10481055
assert_array_equal(rng.data, [0, 1, 1, 0])

0 commit comments

Comments
 (0)