7474 read_bytes ,
7575 read_int ,
7676 read_int_list ,
77- read_int_opt ,
7877 read_str ,
7978 read_str_list ,
8079 read_str_opt ,
8180 write_bytes ,
8281 write_int ,
8382 write_int_list ,
84- write_int_opt ,
8583 write_json_value ,
8684 write_str ,
8785 write_str_list ,
212210 "https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules"
213211)
214212
213+ # Padding when estimating how much time it will take to process a file. This is to avoid
214+ # situations where 100 empty __init__.py files cost less than 1 trivial module.
215+ MIN_SIZE_HINT : Final = 256
216+
215217
216218class SCC :
217219 """A simple class that represents a strongly connected component (import cycle)."""
@@ -240,7 +242,7 @@ def __init__(
240242 self .direct_dependents : list [int ] = []
241243 # Rough estimate of how much time processing this SCC will take, this
242244 # is used for more efficient scheduling across multiple build workers.
243- self .size_hint : int = 0
245+ self .size_hint : int = MIN_SIZE_HINT
244246
245247
246248# TODO: Get rid of BuildResult. We might as well return a BuildManager.
@@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None:
450452 if not worker .connected :
451453 continue
452454 try :
453- send (worker .conn , SccRequestMessage (scc_id = None , import_errors = {}, mod_data = {}))
455+ send (worker .conn , SccRequestMessage (scc_ids = [] , import_errors = {}, mod_data = {}))
454456 except (OSError , IPCException ):
455457 pass
456458 for worker in workers :
@@ -968,6 +970,8 @@ def __init__(
968970 # Stale SCCs that are queued for processing. Each tuple contains SCC size hint,
969971 # SCC adding order (tie-breaker), and the SCC itself.
970972 self .scc_queue : list [tuple [int , int , SCC ]] = []
973+ # Total size hint for SCCs currently in queue.
974+ self .size_in_queue : int = 0
971975 # SCCs that have been fully processed.
972976 self .done_sccs : set [int ] = set ()
973977 # Parallel build workers, list is empty for in-process type-checking.
@@ -1081,6 +1085,8 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S
10811085 state .semantic_analysis_pass1 ()
10821086 self .ast_cache [state .id ] = (state .tree , state .early_errors , state .source_hash )
10831087 self .modules [state .id ] = state .tree
1088+ if state .tree .raw_data is not None :
1089+ state .size_hint = len (state .tree .raw_data .defs ) + MIN_SIZE_HINT
10841090 state .check_blockers ()
10851091 state .setup_errors ()
10861092
@@ -1362,7 +1368,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer:
13621368 try :
13631369 return receive (self .workers [idx ].conn )
13641370 except OSError as exc :
1365- exit_code = self .workers [idx ].proc .poll ()
1371+ try :
1372+ # Give worker process a chance to actually terminate before reporting.
1373+ exit_code = self .workers [idx ].proc .wait (timeout = WORKER_SHUTDOWN_TIMEOUT )
1374+ except TimeoutError :
1375+ exit_code = None
13661376 exit_status = f"exit code { exit_code } " if exit_code is not None else "still running"
13671377 raise OSError (
13681378 f"Worker { idx } disconnected before sending data ({ exit_status } )"
@@ -1375,24 +1385,57 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None:
13751385 else :
13761386 self .scc_queue .extend ([(0 , 0 , scc ) for scc in sccs ])
13771387
def get_scc_batch(self, max_size_in_batch: int) -> list[SCC]:
    """Pop the next group of queued SCCs to hand to a single worker.

    Sending several SCCs per request cuts down on communication overhead,
    while the size cap keeps any one worker from becoming a long pole.

    Args:
        max_size_in_batch: Upper bound on the summed size hints of the
            returned SCCs. The first SCC is always taken, even if it alone
            exceeds this bound, so a non-empty queue never yields an
            empty batch.

    Returns:
        The popped SCCs in heap order (largest size hint first); empty only
        when the queue is empty.
    """
    picked: list[SCC] = []
    batch_total = 0
    pending = self.scc_queue
    while pending:
        # Heap keys are *negated* size hints, so the largest SCC sits on top.
        next_size = -pending[0][0]
        # Stop once the batch is non-empty and the next SCC would overflow it.
        if picked and batch_total + next_size > max_size_in_batch:
            break
        neg_size, _, scc = heappop(pending)
        batch_total -= neg_size
        # Key is negative, so adding it shrinks the running queue total.
        self.size_in_queue += neg_size
        picked.append(scc)
    return picked
1409+
def max_batch_size(self) -> int:
    """Return the size-hint budget for a single batch of SCCs.

    The budget is the total size currently in the queue divided evenly
    among the workers, and halved again on Linux.
    """
    # Linux is good with socket roundtrip latency, so finer-grained
    # batches are affordable there.
    divisor = 2 if sys.platform == "linux" else 1
    fraction = 1 / len(self.workers) / divisor
    return int(self.size_in_queue * fraction)
1417+
13781418 def submit_to_workers (self , graph : Graph , sccs : list [SCC ] | None = None ) -> None :
13791419 if sccs is not None :
13801420 for scc in sccs :
13811421 heappush (self .scc_queue , (- scc .size_hint , self .queue_order , scc ))
1422+ self .size_in_queue += scc .size_hint
13821423 self .queue_order += 1
1424+ max_size_in_batch = self .max_batch_size ()
13831425 while self .scc_queue and self .free_workers :
13841426 idx = self .free_workers .pop ()
1385- _ , _ , scc = heappop ( self .scc_queue )
1427+ scc_batch = self .get_scc_batch ( max_size_in_batch )
13861428 import_errors = {
13871429 mod_id : self .errors .recorded [path ]
1430+ for scc in scc_batch
13881431 for mod_id in scc .mod_ids
13891432 if (path := graph [mod_id ].xpath ) in self .errors .recorded
13901433 }
13911434 t0 = time .time ()
13921435 send (
13931436 self .workers [idx ].conn ,
13941437 SccRequestMessage (
1395- scc_id = scc .id ,
1438+ scc_ids = [ scc .id for scc in scc_batch ] ,
13961439 import_errors = import_errors ,
13971440 mod_data = {
13981441 mod_id : (
@@ -1402,11 +1445,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
14021445 graph [mod_id ].suppressed_deps_opts (),
14031446 tree .raw_data if (tree := graph [mod_id ].tree ) else None ,
14041447 )
1448+ for scc in scc_batch
14051449 for mod_id in scc .mod_ids
14061450 },
14071451 ),
14081452 )
1409- self .add_stats (scc_send_time = time .time () - t0 )
1453+ self .add_stats (scc_requests_sent = 1 , scc_send_time = time .time () - t0 )
14101454
14111455 def wait_for_done (self , graph : Graph ) -> tuple [list [SCC ], bool , dict [str , ModuleResult ]]:
14121456 """Wait for a stale SCC processing to finish.
@@ -1443,13 +1487,12 @@ def wait_for_done_workers(
14431487 if not data .is_interface :
14441488 # Mark worker as free after it finished checking implementation.
14451489 self .free_workers .add (idx )
1446- scc_id = data .scc_id
14471490 if data .blocker is not None :
14481491 raise data .blocker
14491492 assert data .result is not None
14501493 results .update (data .result )
14511494 if data .is_interface :
1452- done_sccs .append ( self .scc_by_id [scc_id ])
1495+ done_sccs .extend ([ self .scc_by_id [scc_id ] for scc_id in data . scc_ids ])
14531496 self .add_stats (scc_wait_time = t1 - t0 , scc_receive_time = time .time () - t1 )
14541497 self .submit_to_workers (graph ) # advance after some workers are free.
14551498 return (
@@ -2756,7 +2799,7 @@ def new_state(
27562799 # import pkg.mod
27572800 if exist_removed_submodules (dependencies , manager ):
27582801 state .needs_parse = True # Same as above, the current state is stale anyway.
2759- state .size_hint = meta .size
2802+ state .size_hint = meta .size + MIN_SIZE_HINT
27602803 else :
27612804 # When doing a fine-grained cache load, pretend we only
27622805 # know about modules that have cache information and defer
@@ -3120,7 +3163,7 @@ def get_source(self) -> str:
31203163
31213164 self .parse_inline_configuration (source )
31223165
3123- self .size_hint = len (source )
3166+ self .size_hint = len (source ) + MIN_SIZE_HINT
31243167 self .time_spent_us += time_spent_us (t0 )
31253168 return source
31263169
@@ -3195,6 +3238,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None =
31953238 self .check_blockers ()
31963239
31973240 manager .ast_cache [self .id ] = (self .tree , self .early_errors , self .source_hash )
3241+ assert self .tree is not None
3242+ if self .tree .raw_data is not None :
3243+ # The size of the serialized tree is a better proxy for file complexity
3244+ # than file size, so we use that when possible. Note that we rely on the
3245+ # lucky coincidence that serialized tree size has the same order of magnitude
3246+ # as file size, so we don't need any normalization factor in situations
3247+ # where parsed and cached files are mixed.
3248+ self .size_hint = len (self .tree .raw_data .defs ) + MIN_SIZE_HINT
31983249 self .setup_errors ()
31993250
32003251 def setup_errors (self ) -> None :
@@ -5084,26 +5135,26 @@ def write(self, buf: WriteBuffer) -> None:
50845135
50855136class SccRequestMessage (IPCMessage ):
50865137 """
5087- A message representing a request to type check an SCC .
5138+ A message representing a request to type check a batch of SCCs .
50885139
5089- If scc_id is None , then it means that the coordinator requested a shutdown.
5140+ If scc_ids is empty , then it means that the coordinator requested a shutdown.
50905141 """
50915142
50925143 def __init__ (
50935144 self ,
50945145 * ,
5095- scc_id : int | None ,
5146+ scc_ids : list [ int ] ,
50965147 import_errors : dict [str , list [ErrorInfo ]],
50975148 mod_data : dict [str , tuple [bytes , FileRawData | None ]],
50985149 ) -> None :
5099- self .scc_id = scc_id
5150+ self .scc_ids = scc_ids
51005151 self .import_errors = import_errors
51015152 self .mod_data = mod_data
51025153
51035154 @classmethod
51045155 def read (cls , buf : ReadBuffer ) -> SccRequestMessage :
51055156 return SccRequestMessage (
5106- scc_id = read_int_opt (buf ),
5157+ scc_ids = read_int_list (buf ),
51075158 import_errors = {
51085159 read_str (buf ): [ErrorInfo .read (buf ) for _ in range (read_int_bare (buf ))]
51095160 for _ in range (read_int_bare (buf ))
@@ -5119,7 +5170,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage:
51195170
51205171 def write (self , buf : WriteBuffer ) -> None :
51215172 write_tag (buf , SCC_REQUEST_MESSAGE )
5122- write_int_opt (buf , self .scc_id )
5173+ write_int_list (buf , self .scc_ids )
51235174 write_int_bare (buf , len (self .import_errors ))
51245175 for path , errors in self .import_errors .items ():
51255176 write_str (buf , path )
@@ -5160,17 +5211,17 @@ def write(self, buf: WriteBuffer) -> None:
51605211
51615212class SccResponseMessage (IPCMessage ):
51625213 """
5163- A message representing a result of type checking an SCC .
5214+ A message representing a result of type checking a batch of SCCs .
51645215
51655216 Only one of `result` or `blocker` can be non-None. The latter means there was
5166- a blocking error while type checking the SCC . The `is_interface` flag indicates
5217+ a blocking error while type checking the SCCs . The `is_interface` flag indicates
51675218 whether this is a result for interface or implementation phase of type-checking.
51685219 """
51695220
51705221 def __init__ (
51715222 self ,
51725223 * ,
5173- scc_id : int ,
5224+ scc_ids : list [ int ] ,
51745225 is_interface : bool ,
51755226 result : dict [str , ModuleResult ] | None = None ,
51765227 blocker : CompileError | None = None ,
@@ -5179,26 +5230,26 @@ def __init__(
51795230 assert blocker is None
51805231 if blocker is not None :
51815232 assert result is None
5182- self .scc_id = scc_id
5233+ self .scc_ids = scc_ids
51835234 self .is_interface = is_interface
51845235 self .result = result
51855236 self .blocker = blocker
51865237
51875238 @classmethod
51885239 def read (cls , buf : ReadBuffer ) -> SccResponseMessage :
5189- scc_id = read_int (buf )
5240+ scc_ids = read_int_list (buf )
51905241 is_interface = read_bool (buf )
51915242 tag = read_tag (buf )
51925243 if tag == LITERAL_NONE :
51935244 return SccResponseMessage (
5194- scc_id = scc_id ,
5245+ scc_ids = scc_ids ,
51955246 is_interface = is_interface ,
51965247 blocker = CompileError (read_str_list (buf ), read_bool (buf ), read_str_opt (buf )),
51975248 )
51985249 else :
51995250 assert tag == DICT_STR_GEN
52005251 return SccResponseMessage (
5201- scc_id = scc_id ,
5252+ scc_ids = scc_ids ,
52025253 is_interface = is_interface ,
52035254 result = {
52045255 read_str_bare (buf ): ModuleResult .read (buf ) for _ in range (read_int_bare (buf ))
@@ -5207,7 +5258,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage:
52075258
52085259 def write (self , buf : WriteBuffer ) -> None :
52095260 write_tag (buf , SCC_RESPONSE_MESSAGE )
5210- write_int (buf , self .scc_id )
5261+ write_int_list (buf , self .scc_ids )
52115262 write_bool (buf , self .is_interface )
52125263 if self .result is None :
52135264 assert self .blocker is not None
0 commit comments