Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4872503

Browse files
authored
TST use global_random_seed in sklearn/feature_extraction/tests/test_image.py (scikit-learn#31310)
1 parent 008d47a commit 4872503

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

sklearn/feature_extraction/tests/test_image.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,23 @@ def test_reconstruct_patches_perfect_color(orange_face):
223223
np.testing.assert_array_almost_equal(face, face_reconstructed)
224224

225225

226-
def test_patch_extractor_fit(downsampled_face_collection):
226+
def test_patch_extractor_fit(downsampled_face_collection, global_random_seed):
227227
faces = downsampled_face_collection
228-
extr = PatchExtractor(patch_size=(8, 8), max_patches=100, random_state=0)
228+
extr = PatchExtractor(
229+
patch_size=(8, 8), max_patches=100, random_state=global_random_seed
230+
)
229231
assert extr == extr.fit(faces)
230232

231233

232-
def test_patch_extractor_max_patches(downsampled_face_collection):
234+
def test_patch_extractor_max_patches(downsampled_face_collection, global_random_seed):
233235
faces = downsampled_face_collection
234236
i_h, i_w = faces.shape[1:3]
235237
p_h, p_w = 8, 8
236238

237239
max_patches = 100
238240
expected_n_patches = len(faces) * max_patches
239241
extr = PatchExtractor(
240-
patch_size=(p_h, p_w), max_patches=max_patches, random_state=0
242+
patch_size=(p_h, p_w), max_patches=max_patches, random_state=global_random_seed
241243
)
242244
patches = extr.transform(faces)
243245
assert patches.shape == (expected_n_patches, p_h, p_w)
@@ -247,35 +249,37 @@ def test_patch_extractor_max_patches(downsampled_face_collection):
247249
(i_h - p_h + 1) * (i_w - p_w + 1) * max_patches
248250
)
249251
extr = PatchExtractor(
250-
patch_size=(p_h, p_w), max_patches=max_patches, random_state=0
252+
patch_size=(p_h, p_w), max_patches=max_patches, random_state=global_random_seed
251253
)
252254
patches = extr.transform(faces)
253255
assert patches.shape == (expected_n_patches, p_h, p_w)
254256

255257

256-
def test_patch_extractor_max_patches_default(downsampled_face_collection):
258+
def test_patch_extractor_max_patches_default(
259+
downsampled_face_collection, global_random_seed
260+
):
257261
faces = downsampled_face_collection
258-
extr = PatchExtractor(max_patches=100, random_state=0)
262+
extr = PatchExtractor(max_patches=100, random_state=global_random_seed)
259263
patches = extr.transform(faces)
260264
assert patches.shape == (len(faces) * 100, 19, 25)
261265

262266

263-
def test_patch_extractor_all_patches(downsampled_face_collection):
267+
def test_patch_extractor_all_patches(downsampled_face_collection, global_random_seed):
264268
faces = downsampled_face_collection
265269
i_h, i_w = faces.shape[1:3]
266270
p_h, p_w = 8, 8
267271
expected_n_patches = len(faces) * (i_h - p_h + 1) * (i_w - p_w + 1)
268-
extr = PatchExtractor(patch_size=(p_h, p_w), random_state=0)
272+
extr = PatchExtractor(patch_size=(p_h, p_w), random_state=global_random_seed)
269273
patches = extr.transform(faces)
270274
assert patches.shape == (expected_n_patches, p_h, p_w)
271275

272276

273-
def test_patch_extractor_color(orange_face):
277+
def test_patch_extractor_color(orange_face, global_random_seed):
274278
faces = _make_images(orange_face)
275279
i_h, i_w = faces.shape[1:3]
276280
p_h, p_w = 8, 8
277281
expected_n_patches = len(faces) * (i_h - p_h + 1) * (i_w - p_w + 1)
278-
extr = PatchExtractor(patch_size=(p_h, p_w), random_state=0)
282+
extr = PatchExtractor(patch_size=(p_h, p_w), random_state=global_random_seed)
279283
patches = extr.transform(faces)
280284
assert patches.shape == (expected_n_patches, p_h, p_w, 3)
281285

0 commit comments

Comments
 (0)