Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 35168b7

Browse files
committed
Batch SCCs for parallel processing
1 parent 9e76d34 commit 35168b7

2 files changed

Lines changed: 103 additions & 45 deletions

File tree

mypy/build.py

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,12 @@
7474
read_bytes,
7575
read_int,
7676
read_int_list,
77-
read_int_opt,
7877
read_str,
7978
read_str_list,
8079
read_str_opt,
8180
write_bytes,
8281
write_int,
8382
write_int_list,
84-
write_int_opt,
8583
write_json_value,
8684
write_str,
8785
write_str_list,
@@ -212,6 +210,10 @@
212210
"https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules"
213211
)
214212

213+
# Padding when estimating how much time it will take to process a file. This is to avoid
214+
# situations where 100 empty __init__.py files cost less than 1 trivial module.
215+
MIN_SIZE_HINT: Final = 256
216+
215217

216218
class SCC:
217219
"""A simple class that represents a strongly connected component (import cycle)."""
@@ -240,7 +242,7 @@ def __init__(
240242
self.direct_dependents: list[int] = []
241243
# Rough estimate of how much time processing this SCC will take, this
242244
# is used for more efficient scheduling across multiple build workers.
243-
self.size_hint: int = 0
245+
self.size_hint: int = MIN_SIZE_HINT
244246

245247

246248
# TODO: Get rid of BuildResult. We might as well return a BuildManager.
@@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None:
450452
if not worker.connected:
451453
continue
452454
try:
453-
send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
455+
send(worker.conn, SccRequestMessage(scc_ids=[], import_errors={}, mod_data={}))
454456
except (OSError, IPCException):
455457
pass
456458
for worker in workers:
@@ -968,6 +970,8 @@ def __init__(
968970
# Stale SCCs that are queued for processing. Each tuple contains SCC size hint,
969971
# SCC adding order (tie-breaker), and the SCC itself.
970972
self.scc_queue: list[tuple[int, int, SCC]] = []
973+
# Total size hint for SCCs currently in queue.
974+
self.size_in_queue: int = 0
971975
# SCCs that have been fully processed.
972976
self.done_sccs: set[int] = set()
973977
# Parallel build workers, list is empty for in-process type-checking.
@@ -1097,6 +1101,9 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S
10971101
state.semantic_analysis_pass1()
10981102
self.ast_cache[state.id] = (state.tree, state.early_errors)
10991103
self.modules[state.id] = state.tree
1104+
assert state.tree is not None
1105+
if state.tree.raw_data is not None:
1106+
state.size_hint = len(state.tree.raw_data.defs) + MIN_SIZE_HINT
11001107
state.check_blockers()
11011108
state.setup_errors()
11021109

@@ -1333,7 +1340,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer:
13331340
try:
13341341
return receive(self.workers[idx].conn)
13351342
except OSError as exc:
1336-
exit_code = self.workers[idx].proc.poll()
1343+
try:
1344+
# Give worker process a chance to actually terminate before reporting.
1345+
exit_code = self.workers[idx].proc.wait(timeout=WORKER_SHUTDOWN_TIMEOUT)
1346+
except TimeoutError:
1347+
exit_code = None
13371348
exit_status = f"exit code {exit_code}" if exit_code is not None else "still running"
13381349
raise OSError(
13391350
f"Worker {idx} disconnected before sending data ({exit_status})"
@@ -1346,24 +1357,57 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None:
13461357
else:
13471358
self.scc_queue.extend([(0, 0, scc) for scc in sccs])
13481359

1360+
def get_scc_batch(self, max_size_in_batch: int) -> list[SCC]:
1361+
"""Get a batch of SCCs from queue to submit to a worker.
1362+
1363+
We batch SCCs to avoid communication overhead, but to avoid
1364+
long poles, we limit fraction of work per worker.
1365+
"""
1366+
batch: list[SCC] = []
1367+
size_in_batch = 0
1368+
while self.scc_queue and (
1369+
# Three notes keep in mind here:
1370+
# * Heap key is *negative* size (so that larger SCCs appear first).
1371+
# * Each batch must have at least one item.
1372+
# * Adding another SCC to batch should not exceed maximum allowed size.
1373+
size_in_batch - self.scc_queue[0][0] <= max_size_in_batch
1374+
or not batch
1375+
):
1376+
size_key, _, scc = heappop(self.scc_queue)
1377+
size_in_batch -= size_key
1378+
self.size_in_queue += size_key
1379+
batch.append(scc)
1380+
return batch
1381+
1382+
def max_batch_size(self) -> int:
1383+
batch_frac = 1 / len(self.workers)
1384+
if sys.platform == "linux":
1385+
# Linux is good with socket roundtrip latency, so we can use
1386+
# more fine-grained batches.
1387+
batch_frac /= 2
1388+
return int(self.size_in_queue * batch_frac)
1389+
13491390
def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None:
13501391
if sccs is not None:
13511392
for scc in sccs:
13521393
heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc))
1394+
self.size_in_queue += scc.size_hint
13531395
self.queue_order += 1
1396+
max_size_in_batch = self.max_batch_size()
13541397
while self.scc_queue and self.free_workers:
13551398
idx = self.free_workers.pop()
1356-
_, _, scc = heappop(self.scc_queue)
1399+
scc_batch = self.get_scc_batch(max_size_in_batch)
13571400
import_errors = {
13581401
mod_id: self.errors.recorded[path]
1402+
for scc in scc_batch
13591403
for mod_id in scc.mod_ids
13601404
if (path := graph[mod_id].xpath) in self.errors.recorded
13611405
}
13621406
t0 = time.time()
13631407
send(
13641408
self.workers[idx].conn,
13651409
SccRequestMessage(
1366-
scc_id=scc.id,
1410+
scc_ids=[scc.id for scc in scc_batch],
13671411
import_errors=import_errors,
13681412
mod_data={
13691413
mod_id: (
@@ -1373,11 +1417,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
13731417
graph[mod_id].suppressed_deps_opts(),
13741418
tree.raw_data if (tree := graph[mod_id].tree) else None,
13751419
)
1420+
for scc in scc_batch
13761421
for mod_id in scc.mod_ids
13771422
},
13781423
),
13791424
)
1380-
self.add_stats(scc_send_time=time.time() - t0)
1425+
self.add_stats(scc_requests_sent=1, scc_send_time=time.time() - t0)
13811426

13821427
def wait_for_done(self, graph: Graph) -> tuple[list[SCC], bool, dict[str, ModuleResult]]:
13831428
"""Wait for a stale SCC processing to finish.
@@ -1414,13 +1459,12 @@ def wait_for_done_workers(
14141459
if not data.is_interface:
14151460
# Mark worker as free after it finished checking implementation.
14161461
self.free_workers.add(idx)
1417-
scc_id = data.scc_id
14181462
if data.blocker is not None:
14191463
raise data.blocker
14201464
assert data.result is not None
14211465
results.update(data.result)
14221466
if data.is_interface:
1423-
done_sccs.append(self.scc_by_id[scc_id])
1467+
done_sccs.extend([self.scc_by_id[scc_id] for scc_id in data.scc_ids])
14241468
self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1)
14251469
self.submit_to_workers(graph) # advance after some workers are free.
14261470
return (
@@ -2727,7 +2771,7 @@ def new_state(
27272771
# import pkg.mod
27282772
if exist_removed_submodules(dependencies, manager):
27292773
state.needs_parse = True # Same as above, the current state is stale anyway.
2730-
state.size_hint = meta.size
2774+
state.size_hint = meta.size + MIN_SIZE_HINT
27312775
else:
27322776
# When doing a fine-grained cache load, pretend we only
27332777
# know about modules that have cache information and defer
@@ -3092,7 +3136,7 @@ def get_source(self) -> str:
30923136
self.parse_inline_configuration(source)
30933137
self.check_for_invalid_options()
30943138

3095-
self.size_hint = len(source)
3139+
self.size_hint = len(source) + MIN_SIZE_HINT
30963140
self.time_spent_us += time_spent_us(t0)
30973141
return source
30983142

@@ -3157,6 +3201,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None =
31573201
self.check_blockers()
31583202

31593203
manager.ast_cache[self.id] = (self.tree, self.early_errors)
3204+
assert self.tree is not None
3205+
if self.tree.raw_data is not None:
3206+
# Size of serialized tree is a better proxy for file complexity than
3207+
# file size, so we use that when possible. Note that we rely on lucky
3208+
# coincidence that serialized tree size has same order of magnitude as
3209+
# file size, so we don't need any normalization factor in situations
3210+
# where parsed and cached files are mixed.
3211+
self.size_hint = len(self.tree.raw_data.defs) + MIN_SIZE_HINT
31603212
self.setup_errors()
31613213

31623214
def setup_errors(self) -> None:
@@ -5054,26 +5106,26 @@ def write(self, buf: WriteBuffer) -> None:
50545106

50555107
class SccRequestMessage(IPCMessage):
50565108
"""
5057-
A message representing a request to type check an SCC.
5109+
A message representing a request to type check a batch of SCCs.
50585110
5059-
If scc_id is None, then it means that the coordinator requested a shutdown.
5111+
If scc_ids is empty, then it means that the coordinator requested a shutdown.
50605112
"""
50615113

50625114
def __init__(
50635115
self,
50645116
*,
5065-
scc_id: int | None,
5117+
scc_ids: list[int],
50665118
import_errors: dict[str, list[ErrorInfo]],
50675119
mod_data: dict[str, tuple[bytes, FileRawData | None]],
50685120
) -> None:
5069-
self.scc_id = scc_id
5121+
self.scc_ids = scc_ids
50705122
self.import_errors = import_errors
50715123
self.mod_data = mod_data
50725124

50735125
@classmethod
50745126
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
50755127
return SccRequestMessage(
5076-
scc_id=read_int_opt(buf),
5128+
scc_ids=read_int_list(buf),
50775129
import_errors={
50785130
read_str(buf): [ErrorInfo.read(buf) for _ in range(read_int_bare(buf))]
50795131
for _ in range(read_int_bare(buf))
@@ -5089,7 +5141,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage:
50895141

50905142
def write(self, buf: WriteBuffer) -> None:
50915143
write_tag(buf, SCC_REQUEST_MESSAGE)
5092-
write_int_opt(buf, self.scc_id)
5144+
write_int_list(buf, self.scc_ids)
50935145
write_int_bare(buf, len(self.import_errors))
50945146
for path, errors in self.import_errors.items():
50955147
write_str(buf, path)
@@ -5130,17 +5182,17 @@ def write(self, buf: WriteBuffer) -> None:
51305182

51315183
class SccResponseMessage(IPCMessage):
51325184
"""
5133-
A message representing a result of type checking an SCC.
5185+
A message representing a result of type checking a batch of SCCs.
51345186
51355187
Only one of `result` or `blocker` can be non-None. The latter means there was
5136-
a blocking error while type checking the SCC. The `is_interface` flag indicates
5188+
a blocking error while type checking the SCCs. The `is_interface` flag indicates
51375189
whether this is a result for interface or implementation phase of type-checking.
51385190
"""
51395191

51405192
def __init__(
51415193
self,
51425194
*,
5143-
scc_id: int,
5195+
scc_ids: list[int],
51445196
is_interface: bool,
51455197
result: dict[str, ModuleResult] | None = None,
51465198
blocker: CompileError | None = None,
@@ -5149,26 +5201,26 @@ def __init__(
51495201
assert blocker is None
51505202
if blocker is not None:
51515203
assert result is None
5152-
self.scc_id = scc_id
5204+
self.scc_ids = scc_ids
51535205
self.is_interface = is_interface
51545206
self.result = result
51555207
self.blocker = blocker
51565208

51575209
@classmethod
51585210
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
5159-
scc_id = read_int(buf)
5211+
scc_ids = read_int_list(buf)
51605212
is_interface = read_bool(buf)
51615213
tag = read_tag(buf)
51625214
if tag == LITERAL_NONE:
51635215
return SccResponseMessage(
5164-
scc_id=scc_id,
5216+
scc_ids=scc_ids,
51655217
is_interface=is_interface,
51665218
blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)),
51675219
)
51685220
else:
51695221
assert tag == DICT_STR_GEN
51705222
return SccResponseMessage(
5171-
scc_id=scc_id,
5223+
scc_ids=scc_ids,
51725224
is_interface=is_interface,
51735225
result={
51745226
read_str_bare(buf): ModuleResult.read(buf) for _ in range(read_int_bare(buf))
@@ -5177,7 +5229,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage:
51775229

51785230
def write(self, buf: WriteBuffer) -> None:
51795231
write_tag(buf, SCC_RESPONSE_MESSAGE)
5180-
write_int(buf, self.scc_id)
5232+
write_int_list(buf, self.scc_ids)
51815233
write_bool(buf, self.is_interface)
51825234
if self.result is None:
51835235
assert self.blocker is not None

0 commit comments

Comments
 (0)