Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d1f4760

Browse files
authored
[BE] Small cleanup in get_files_to_run (#1923)
Make `get_all_files` and `calculate_shards` work regardless of the script invocation cwd. Add `test_files_to_run.py`. Do not leak a file descriptor while reading `metadata.json`. Use list comprehensions instead of `list(map(lambda ...))`.
1 parent 458285f commit d1f4760

2 files changed

Lines changed: 31 additions & 9 deletions

File tree

.jenkins/get_files_to_run.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@
66
from remove_runnable_code import remove_runnable_code
77

88

9+
# Calculate repo base dir
10+
REPO_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11+
912
def get_all_files(encoding="utf-8") -> List[str]:
1013
sources = [
1114
"beginner_source",
@@ -16,11 +19,12 @@ def get_all_files(encoding="utf-8") -> List[str]:
1619
]
1720
cmd = ["find"] + sources + ["-name", "*.py", "-not", "-path", "*/data/*"]
1821

19-
return run(cmd, capture_output=True).stdout.decode(encoding).splitlines()
22+
return run(cmd, capture_output=True, cwd=REPO_BASE_DIR).stdout.decode(encoding).splitlines()
2023

2124

2225
def calculate_shards(all_files, num_shards=20):
23-
metadata = json.load(open(".jenkins/metadata.json"))
26+
with open(os.path.join(REPO_BASE_DIR, ".jenkins", "metadata.json")) as fp:
27+
metadata = json.load(fp)
2428
sharded_files = [(0.0, []) for _ in range(num_shards)]
2529

2630
def get_duration(file):
@@ -47,9 +51,7 @@ def add_to_shard(i, filename):
4751
# so we'll add all the jobs that need this machine to the 0th worker
4852
add_to_shard(0, filename)
4953

50-
all_other_files = list(
51-
filter(lambda x: x not in needs_gpu_nvidia_small_multi, all_files)
52-
)
54+
all_other_files = [x for x in all_files if x not in needs_gpu_nvidia_small_multi]
5355

5456
sorted_files = sorted(all_other_files, key=get_duration, reverse=True,)
5557

@@ -58,23 +60,23 @@ def add_to_shard(i, filename):
5860
0
5961
]
6062
add_to_shard(min_shard_index, filename)
61-
return list(map(lambda x: x[1], sharded_files))
63+
return [x[1] for x in sharded_files]
6264

6365

64-
def remove_other_files(all_files, files_to_run):
66+
def remove_other_files(all_files, files_to_run) -> None:
6567
for file in all_files:
6668
if file not in files_to_run:
6769
remove_runnable_code(file, file)
6870

6971

70-
def main():
72+
def main() -> None:
7173
num_shards = int(os.environ.get("NUM_WORKERS", 20))
7274
shard_num = int(os.environ.get("WORKER_ID", 0))
7375

7476
all_files = get_all_files()
7577
files_to_run = calculate_shards(all_files, num_shards=num_shards)[shard_num]
7678
remove_other_files(all_files, files_to_run)
77-
stripped_file_names = list(map(lambda x: Path(x).stem, files_to_run))
79+
stripped_file_names = [Path(x).stem for x in files_to_run]
7880
print(" ".join(stripped_file_names))
7981

8082

.jenkins/test_files_to_run.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env python
2+
from get_files_to_run import get_all_files, calculate_shards
3+
from unittest import TestCase, main
4+
from functools import reduce
5+
6+
class TestSharding(TestCase):
7+
def test_no_sharding(self):
8+
all_files=get_all_files()
9+
sharded_files = calculate_shards(all_files, 1)
10+
self.assertSetEqual(set(all_files), set(sharded_files[0]))
11+
12+
def test_sharding(self, num_shards=20):
13+
all_files=get_all_files()
14+
sharded_files = map(set, calculate_shards(all_files, num_shards))
15+
self.assertSetEqual(set(all_files), reduce(lambda x,y: x.union(y), sharded_files, set()))
16+
17+
18+
19+
if __name__ == "__main__":
20+
main()

0 commit comments

Comments
 (0)