Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit db2faa7

Browse files
authored
Batch SCCs for parallel processing (#21287)
This is a follow-up for #21119 Implementation mostly straightforward. Some comments: * After experimenting with self-check and torch, it looks like 1/N for Mac and 1/2N for Linux should be a safe batch size limit. Because of Zipf's Law together with "rounding down" logic, the average batch size is significantly smaller than the limit. For example, with six workers on Linux, a single worker rarely takes more than 5% of the queue. * I add padding to module `size_hint`. Apparently, there are many empty `__init__.py` files, but processing an empty file still costs some non-trivial amount of time. * I use the size of the serialized tree as a better proxy for module complexity. This is probably not very important, but will avoid weird cases where we have a file with lots of comments. This will also account for conditional function body stripping. Btw we should probably serialize all docstrings in `ast_serialize` as `"<docstring>"` or similar (we can't skip them completely). * I write cache for interface after each SCC in batch, while cache for all implementations is written in one go. There is no deep logic behind this, this is the simplest way to do it because of how code is currently structured. If needed for performance reasons, this can be tweaked one way or the other (i.e. fewer larger cache commits or more smaller cache commits) with not too much effort. * While working on this I accidentally noticed that our crash detector always reports a worker as "still running", so I added a little wait in case of a crash.
1 parent 2b136f0 commit db2faa7

2 files changed

Lines changed: 102 additions & 45 deletions

File tree

mypy/build.py

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,12 @@
7474
read_bytes,
7575
read_int,
7676
read_int_list,
77-
read_int_opt,
7877
read_str,
7978
read_str_list,
8079
read_str_opt,
8180
write_bytes,
8281
write_int,
8382
write_int_list,
84-
write_int_opt,
8583
write_json_value,
8684
write_str,
8785
write_str_list,
@@ -212,6 +210,10 @@
212210
"https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules"
213211
)
214212

213+
# Padding when estimating how much time it will take to process a file. This is to avoid
214+
# situations where 100 empty __init__.py files cost less than 1 trivial module.
215+
MIN_SIZE_HINT: Final = 256
216+
215217

216218
class SCC:
217219
"""A simple class that represents a strongly connected component (import cycle)."""
@@ -240,7 +242,7 @@ def __init__(
240242
self.direct_dependents: list[int] = []
241243
# Rough estimate of how much time processing this SCC will take, this
242244
# is used for more efficient scheduling across multiple build workers.
243-
self.size_hint: int = 0
245+
self.size_hint: int = MIN_SIZE_HINT
244246

245247

246248
# TODO: Get rid of BuildResult. We might as well return a BuildManager.
@@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None:
450452
if not worker.connected:
451453
continue
452454
try:
453-
send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
455+
send(worker.conn, SccRequestMessage(scc_ids=[], import_errors={}, mod_data={}))
454456
except (OSError, IPCException):
455457
pass
456458
for worker in workers:
@@ -968,6 +970,8 @@ def __init__(
968970
# Stale SCCs that are queued for processing. Each tuple contains SCC size hint,
969971
# SCC adding order (tie-breaker), and the SCC itself.
970972
self.scc_queue: list[tuple[int, int, SCC]] = []
973+
# Total size hint for SCCs currently in queue.
974+
self.size_in_queue: int = 0
971975
# SCCs that have been fully processed.
972976
self.done_sccs: set[int] = set()
973977
# Parallel build workers, list is empty for in-process type-checking.
@@ -1081,6 +1085,8 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S
10811085
state.semantic_analysis_pass1()
10821086
self.ast_cache[state.id] = (state.tree, state.early_errors, state.source_hash)
10831087
self.modules[state.id] = state.tree
1088+
if state.tree.raw_data is not None:
1089+
state.size_hint = len(state.tree.raw_data.defs) + MIN_SIZE_HINT
10841090
state.check_blockers()
10851091
state.setup_errors()
10861092

@@ -1362,7 +1368,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer:
13621368
try:
13631369
return receive(self.workers[idx].conn)
13641370
except OSError as exc:
1365-
exit_code = self.workers[idx].proc.poll()
1371+
try:
1372+
# Give worker process a chance to actually terminate before reporting.
1373+
exit_code = self.workers[idx].proc.wait(timeout=WORKER_SHUTDOWN_TIMEOUT)
1374+
except TimeoutError:
1375+
exit_code = None
13661376
exit_status = f"exit code {exit_code}" if exit_code is not None else "still running"
13671377
raise OSError(
13681378
f"Worker {idx} disconnected before sending data ({exit_status})"
@@ -1375,24 +1385,57 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None:
13751385
else:
13761386
self.scc_queue.extend([(0, 0, scc) for scc in sccs])
13771387

1388+
def get_scc_batch(self, max_size_in_batch: int) -> list[SCC]:
1389+
"""Get a batch of SCCs from queue to submit to a worker.
1390+
1391+
We batch SCCs to avoid communication overhead, but to avoid
1392+
long poles, we limit the fraction of work per worker.
1393+
"""
1394+
batch: list[SCC] = []
1395+
size_in_batch = 0
1396+
while self.scc_queue and (
1397+
# Three notes to keep in mind here:
1398+
# * Heap key is *negative* size (so that larger SCCs appear first).
1399+
# * Each batch must have at least one item.
1400+
# * Adding another SCC to batch should not exceed maximum allowed size.
1401+
size_in_batch - self.scc_queue[0][0] <= max_size_in_batch
1402+
or not batch
1403+
):
1404+
size_key, _, scc = heappop(self.scc_queue)
1405+
size_in_batch -= size_key
1406+
self.size_in_queue += size_key
1407+
batch.append(scc)
1408+
return batch
1409+
1410+
def max_batch_size(self) -> int:
1411+
batch_frac = 1 / len(self.workers)
1412+
if sys.platform == "linux":
1413+
# Linux is good with socket roundtrip latency, so we can use
1414+
# more fine-grained batches.
1415+
batch_frac /= 2
1416+
return int(self.size_in_queue * batch_frac)
1417+
13781418
def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None:
13791419
if sccs is not None:
13801420
for scc in sccs:
13811421
heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc))
1422+
self.size_in_queue += scc.size_hint
13821423
self.queue_order += 1
1424+
max_size_in_batch = self.max_batch_size()
13831425
while self.scc_queue and self.free_workers:
13841426
idx = self.free_workers.pop()
1385-
_, _, scc = heappop(self.scc_queue)
1427+
scc_batch = self.get_scc_batch(max_size_in_batch)
13861428
import_errors = {
13871429
mod_id: self.errors.recorded[path]
1430+
for scc in scc_batch
13881431
for mod_id in scc.mod_ids
13891432
if (path := graph[mod_id].xpath) in self.errors.recorded
13901433
}
13911434
t0 = time.time()
13921435
send(
13931436
self.workers[idx].conn,
13941437
SccRequestMessage(
1395-
scc_id=scc.id,
1438+
scc_ids=[scc.id for scc in scc_batch],
13961439
import_errors=import_errors,
13971440
mod_data={
13981441
mod_id: (
@@ -1402,11 +1445,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
14021445
graph[mod_id].suppressed_deps_opts(),
14031446
tree.raw_data if (tree := graph[mod_id].tree) else None,
14041447
)
1448+
for scc in scc_batch
14051449
for mod_id in scc.mod_ids
14061450
},
14071451
),
14081452
)
1409-
self.add_stats(scc_send_time=time.time() - t0)
1453+
self.add_stats(scc_requests_sent=1, scc_send_time=time.time() - t0)
14101454

14111455
def wait_for_done(self, graph: Graph) -> tuple[list[SCC], bool, dict[str, ModuleResult]]:
14121456
"""Wait for a stale SCC processing to finish.
@@ -1443,13 +1487,12 @@ def wait_for_done_workers(
14431487
if not data.is_interface:
14441488
# Mark worker as free after it finished checking implementation.
14451489
self.free_workers.add(idx)
1446-
scc_id = data.scc_id
14471490
if data.blocker is not None:
14481491
raise data.blocker
14491492
assert data.result is not None
14501493
results.update(data.result)
14511494
if data.is_interface:
1452-
done_sccs.append(self.scc_by_id[scc_id])
1495+
done_sccs.extend([self.scc_by_id[scc_id] for scc_id in data.scc_ids])
14531496
self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1)
14541497
self.submit_to_workers(graph) # advance after some workers are free.
14551498
return (
@@ -2756,7 +2799,7 @@ def new_state(
27562799
# import pkg.mod
27572800
if exist_removed_submodules(dependencies, manager):
27582801
state.needs_parse = True # Same as above, the current state is stale anyway.
2759-
state.size_hint = meta.size
2802+
state.size_hint = meta.size + MIN_SIZE_HINT
27602803
else:
27612804
# When doing a fine-grained cache load, pretend we only
27622805
# know about modules that have cache information and defer
@@ -3120,7 +3163,7 @@ def get_source(self) -> str:
31203163

31213164
self.parse_inline_configuration(source)
31223165

3123-
self.size_hint = len(source)
3166+
self.size_hint = len(source) + MIN_SIZE_HINT
31243167
self.time_spent_us += time_spent_us(t0)
31253168
return source
31263169

@@ -3195,6 +3238,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None =
31953238
self.check_blockers()
31963239

31973240
manager.ast_cache[self.id] = (self.tree, self.early_errors, self.source_hash)
3241+
assert self.tree is not None
3242+
if self.tree.raw_data is not None:
3243+
# Size of serialized tree is a better proxy for file complexity than
3244+
# file size, so we use that when possible. Note that we rely on lucky
3245+
# coincidence that serialized tree size has same order of magnitude as
3246+
# file size, so we don't need any normalization factor in situations
3247+
# where parsed and cached files are mixed.
3248+
self.size_hint = len(self.tree.raw_data.defs) + MIN_SIZE_HINT
31983249
self.setup_errors()
31993250

32003251
def setup_errors(self) -> None:
@@ -5084,26 +5135,26 @@ def write(self, buf: WriteBuffer) -> None:
50845135

50855136
class SccRequestMessage(IPCMessage):
50865137
"""
5087-
A message representing a request to type check an SCC.
5138+
A message representing a request to type check a batch of SCCs.
50885139
5089-
If scc_id is None, then it means that the coordinator requested a shutdown.
5140+
If scc_ids is empty, then it means that the coordinator requested a shutdown.
50905141
"""
50915142

50925143
def __init__(
50935144
self,
50945145
*,
5095-
scc_id: int | None,
5146+
scc_ids: list[int],
50965147
import_errors: dict[str, list[ErrorInfo]],
50975148
mod_data: dict[str, tuple[bytes, FileRawData | None]],
50985149
) -> None:
5099-
self.scc_id = scc_id
5150+
self.scc_ids = scc_ids
51005151
self.import_errors = import_errors
51015152
self.mod_data = mod_data
51025153

51035154
@classmethod
51045155
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
51055156
return SccRequestMessage(
5106-
scc_id=read_int_opt(buf),
5157+
scc_ids=read_int_list(buf),
51075158
import_errors={
51085159
read_str(buf): [ErrorInfo.read(buf) for _ in range(read_int_bare(buf))]
51095160
for _ in range(read_int_bare(buf))
@@ -5119,7 +5170,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage:
51195170

51205171
def write(self, buf: WriteBuffer) -> None:
51215172
write_tag(buf, SCC_REQUEST_MESSAGE)
5122-
write_int_opt(buf, self.scc_id)
5173+
write_int_list(buf, self.scc_ids)
51235174
write_int_bare(buf, len(self.import_errors))
51245175
for path, errors in self.import_errors.items():
51255176
write_str(buf, path)
@@ -5160,17 +5211,17 @@ def write(self, buf: WriteBuffer) -> None:
51605211

51615212
class SccResponseMessage(IPCMessage):
51625213
"""
5163-
A message representing a result of type checking an SCC.
5214+
A message representing a result of type checking a batch of SCCs.
51645215
51655216
Only one of `result` or `blocker` can be non-None. The latter means there was
5166-
a blocking error while type checking the SCC. The `is_interface` flag indicates
5217+
a blocking error while type checking the SCCs. The `is_interface` flag indicates
51675218
whether this is a result for interface or implementation phase of type-checking.
51685219
"""
51695220

51705221
def __init__(
51715222
self,
51725223
*,
5173-
scc_id: int,
5224+
scc_ids: list[int],
51745225
is_interface: bool,
51755226
result: dict[str, ModuleResult] | None = None,
51765227
blocker: CompileError | None = None,
@@ -5179,26 +5230,26 @@ def __init__(
51795230
assert blocker is None
51805231
if blocker is not None:
51815232
assert result is None
5182-
self.scc_id = scc_id
5233+
self.scc_ids = scc_ids
51835234
self.is_interface = is_interface
51845235
self.result = result
51855236
self.blocker = blocker
51865237

51875238
@classmethod
51885239
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
5189-
scc_id = read_int(buf)
5240+
scc_ids = read_int_list(buf)
51905241
is_interface = read_bool(buf)
51915242
tag = read_tag(buf)
51925243
if tag == LITERAL_NONE:
51935244
return SccResponseMessage(
5194-
scc_id=scc_id,
5245+
scc_ids=scc_ids,
51955246
is_interface=is_interface,
51965247
blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)),
51975248
)
51985249
else:
51995250
assert tag == DICT_STR_GEN
52005251
return SccResponseMessage(
5201-
scc_id=scc_id,
5252+
scc_ids=scc_ids,
52025253
is_interface=is_interface,
52035254
result={
52045255
read_str_bare(buf): ModuleResult.read(buf) for _ in range(read_int_bare(buf))
@@ -5207,7 +5258,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage:
52075258

52085259
def write(self, buf: WriteBuffer) -> None:
52095260
write_tag(buf, SCC_RESPONSE_MESSAGE)
5210-
write_int(buf, self.scc_id)
5261+
write_int_list(buf, self.scc_ids)
52115262
write_bool(buf, self.is_interface)
52125263
if self.result is None:
52135264
assert self.blocker is not None

0 commit comments

Comments
 (0)