7474 read_bytes ,
7575 read_int ,
7676 read_int_list ,
77- read_int_opt ,
7877 read_str ,
7978 read_str_list ,
8079 read_str_opt ,
8180 write_bytes ,
8281 write_int ,
8382 write_int_list ,
84- write_int_opt ,
8583 write_json_value ,
8684 write_str ,
8785 write_str_list ,
212210 "https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules"
213211)
214212
213+ # Padding when estimating how much time it will take to process a file. This is to avoid
214+ # situations where 100 empty __init__.py files cost less than 1 trivial module.
215+ MIN_SIZE_HINT : Final = 256
216+
215217
216218class SCC :
217219 """A simple class that represents a strongly connected component (import cycle)."""
@@ -240,7 +242,7 @@ def __init__(
240242 self .direct_dependents : list [int ] = []
241243 # Rough estimate of how much time processing this SCC will take, this
242244 # is used for more efficient scheduling across multiple build workers.
243- self .size_hint : int = 0
245+ self .size_hint : int = MIN_SIZE_HINT
244246
245247
246248# TODO: Get rid of BuildResult. We might as well return a BuildManager.
@@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None:
450452 if not worker .connected :
451453 continue
452454 try :
453- send (worker .conn , SccRequestMessage (scc_id = None , import_errors = {}, mod_data = {}))
455+ send (worker .conn , SccRequestMessage (scc_ids = [] , import_errors = {}, mod_data = {}))
454456 except (OSError , IPCException ):
455457 pass
456458 for worker in workers :
@@ -968,6 +970,8 @@ def __init__(
968970 # Stale SCCs that are queued for processing. Each tuple contains SCC size hint,
969971 # SCC adding order (tie-breaker), and the SCC itself.
970972 self .scc_queue : list [tuple [int , int , SCC ]] = []
973+ # Total size hint for SCCs currently in queue.
974+ self .size_in_queue : int = 0
971975 # SCCs that have been fully processed.
972976 self .done_sccs : set [int ] = set ()
973977 # Parallel build workers, list is empty for in-process type-checking.
@@ -1097,6 +1101,9 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S
10971101 state .semantic_analysis_pass1 ()
10981102 self .ast_cache [state .id ] = (state .tree , state .early_errors )
10991103 self .modules [state .id ] = state .tree
1104+ assert state .tree is not None
1105+ if state .tree .raw_data is not None :
1106+ state .size_hint = len (state .tree .raw_data .defs ) + MIN_SIZE_HINT
11001107 state .check_blockers ()
11011108 state .setup_errors ()
11021109
@@ -1333,7 +1340,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer:
13331340 try :
13341341 return receive (self .workers [idx ].conn )
13351342 except OSError as exc :
1336- exit_code = self .workers [idx ].proc .poll ()
1343+ try :
1344+ # Give worker process a chance to actually terminate before reporting.
1345+ exit_code = self .workers [idx ].proc .wait (timeout = WORKER_SHUTDOWN_TIMEOUT )
1346+ except TimeoutError :
1347+ exit_code = None
13371348 exit_status = f"exit code { exit_code } " if exit_code is not None else "still running"
13381349 raise OSError (
13391350 f"Worker { idx } disconnected before sending data ({ exit_status } )"
@@ -1346,24 +1357,57 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None:
13461357 else :
13471358 self .scc_queue .extend ([(0 , 0 , scc ) for scc in sccs ])
13481359
1360+ def get_scc_batch (self , max_size_in_batch : int ) -> list [SCC ]:
1361+ """Get a batch of SCCs from queue to submit to a worker.
1362+
1363+ We batch SCCs to avoid communication overhead, but to avoid
1364+ long poles, we limit fraction of work per worker.
1365+ """
1366+ batch : list [SCC ] = []
1367+ size_in_batch = 0
1368+ while self .scc_queue and (
1369+ # Three notes keep in mind here:
1370+ # * Heap key is *negative* size (so that larger SCCs appear first).
1371+ # * Each batch must have at least one item.
1372+ # * Adding another SCC to batch should not exceed maximum allowed size.
1373+ size_in_batch - self .scc_queue [0 ][0 ] <= max_size_in_batch
1374+ or not batch
1375+ ):
1376+ size_key , _ , scc = heappop (self .scc_queue )
1377+ size_in_batch -= size_key
1378+ self .size_in_queue += size_key
1379+ batch .append (scc )
1380+ return batch
1381+
1382+ def max_batch_size (self ) -> int :
1383+ batch_frac = 1 / len (self .workers )
1384+ if sys .platform == "linux" :
1385+ # Linux is good with socket roundtrip latency, so we can use
1386+ # more fine-grained batches.
1387+ batch_frac /= 2
1388+ return int (self .size_in_queue * batch_frac )
1389+
13491390 def submit_to_workers (self , graph : Graph , sccs : list [SCC ] | None = None ) -> None :
13501391 if sccs is not None :
13511392 for scc in sccs :
13521393 heappush (self .scc_queue , (- scc .size_hint , self .queue_order , scc ))
1394+ self .size_in_queue += scc .size_hint
13531395 self .queue_order += 1
1396+ max_size_in_batch = self .max_batch_size ()
13541397 while self .scc_queue and self .free_workers :
13551398 idx = self .free_workers .pop ()
1356- _ , _ , scc = heappop ( self .scc_queue )
1399+ scc_batch = self .get_scc_batch ( max_size_in_batch )
13571400 import_errors = {
13581401 mod_id : self .errors .recorded [path ]
1402+ for scc in scc_batch
13591403 for mod_id in scc .mod_ids
13601404 if (path := graph [mod_id ].xpath ) in self .errors .recorded
13611405 }
13621406 t0 = time .time ()
13631407 send (
13641408 self .workers [idx ].conn ,
13651409 SccRequestMessage (
1366- scc_id = scc .id ,
1410+ scc_ids = [ scc .id for scc in scc_batch ] ,
13671411 import_errors = import_errors ,
13681412 mod_data = {
13691413 mod_id : (
@@ -1373,11 +1417,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
13731417 graph [mod_id ].suppressed_deps_opts (),
13741418 tree .raw_data if (tree := graph [mod_id ].tree ) else None ,
13751419 )
1420+ for scc in scc_batch
13761421 for mod_id in scc .mod_ids
13771422 },
13781423 ),
13791424 )
1380- self .add_stats (scc_send_time = time .time () - t0 )
1425+ self .add_stats (scc_requests_sent = 1 , scc_send_time = time .time () - t0 )
13811426
13821427 def wait_for_done (self , graph : Graph ) -> tuple [list [SCC ], bool , dict [str , ModuleResult ]]:
13831428 """Wait for a stale SCC processing to finish.
@@ -1414,13 +1459,12 @@ def wait_for_done_workers(
14141459 if not data .is_interface :
14151460 # Mark worker as free after it finished checking implementation.
14161461 self .free_workers .add (idx )
1417- scc_id = data .scc_id
14181462 if data .blocker is not None :
14191463 raise data .blocker
14201464 assert data .result is not None
14211465 results .update (data .result )
14221466 if data .is_interface :
1423- done_sccs .append ( self .scc_by_id [scc_id ])
1467+ done_sccs .extend ([ self .scc_by_id [scc_id ] for scc_id in data . scc_ids ])
14241468 self .add_stats (scc_wait_time = t1 - t0 , scc_receive_time = time .time () - t1 )
14251469 self .submit_to_workers (graph ) # advance after some workers are free.
14261470 return (
@@ -2727,7 +2771,7 @@ def new_state(
27272771 # import pkg.mod
27282772 if exist_removed_submodules (dependencies , manager ):
27292773 state .needs_parse = True # Same as above, the current state is stale anyway.
2730- state .size_hint = meta .size
2774+ state .size_hint = meta .size + MIN_SIZE_HINT
27312775 else :
27322776 # When doing a fine-grained cache load, pretend we only
27332777 # know about modules that have cache information and defer
@@ -3092,7 +3136,7 @@ def get_source(self) -> str:
30923136 self .parse_inline_configuration (source )
30933137 self .check_for_invalid_options ()
30943138
3095- self .size_hint = len (source )
3139+ self .size_hint = len (source ) + MIN_SIZE_HINT
30963140 self .time_spent_us += time_spent_us (t0 )
30973141 return source
30983142
@@ -3157,6 +3201,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None =
31573201 self .check_blockers ()
31583202
31593203 manager .ast_cache [self .id ] = (self .tree , self .early_errors )
3204+ assert self .tree is not None
3205+ if self .tree .raw_data is not None :
3206+ # Size of serialized tree is a better proxy for file complexity than
3207+ # file size, so we use that when possible. Note that we rely on lucky
3208+ # coincidence that serialized tree size has same order of magnitude as
3209+ # file size, so we don't need any normalization factor in situations
3210+ # where parsed and cached files are mixed.
3211+ self .size_hint = len (self .tree .raw_data .defs ) + MIN_SIZE_HINT
31603212 self .setup_errors ()
31613213
31623214 def setup_errors (self ) -> None :
@@ -5054,26 +5106,26 @@ def write(self, buf: WriteBuffer) -> None:
50545106
50555107class SccRequestMessage (IPCMessage ):
50565108 """
5057- A message representing a request to type check an SCC .
5109+ A message representing a request to type check a batch of SCCs .
50585110
5059- If scc_id is None , then it means that the coordinator requested a shutdown.
5111+ If scc_ids is empty , then it means that the coordinator requested a shutdown.
50605112 """
50615113
50625114 def __init__ (
50635115 self ,
50645116 * ,
5065- scc_id : int | None ,
5117+ scc_ids : list [ int ] ,
50665118 import_errors : dict [str , list [ErrorInfo ]],
50675119 mod_data : dict [str , tuple [bytes , FileRawData | None ]],
50685120 ) -> None :
5069- self .scc_id = scc_id
5121+ self .scc_ids = scc_ids
50705122 self .import_errors = import_errors
50715123 self .mod_data = mod_data
50725124
50735125 @classmethod
50745126 def read (cls , buf : ReadBuffer ) -> SccRequestMessage :
50755127 return SccRequestMessage (
5076- scc_id = read_int_opt (buf ),
5128+ scc_ids = read_int_list (buf ),
50775129 import_errors = {
50785130 read_str (buf ): [ErrorInfo .read (buf ) for _ in range (read_int_bare (buf ))]
50795131 for _ in range (read_int_bare (buf ))
@@ -5089,7 +5141,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage:
50895141
50905142 def write (self , buf : WriteBuffer ) -> None :
50915143 write_tag (buf , SCC_REQUEST_MESSAGE )
5092- write_int_opt (buf , self .scc_id )
5144+ write_int_list (buf , self .scc_ids )
50935145 write_int_bare (buf , len (self .import_errors ))
50945146 for path , errors in self .import_errors .items ():
50955147 write_str (buf , path )
@@ -5130,17 +5182,17 @@ def write(self, buf: WriteBuffer) -> None:
51305182
51315183class SccResponseMessage (IPCMessage ):
51325184 """
5133- A message representing a result of type checking an SCC .
5185+ A message representing a result of type checking a batch of SCCs .
51345186
51355187 Only one of `result` or `blocker` can be non-None. The latter means there was
5136- a blocking error while type checking the SCC . The `is_interface` flag indicates
5188+ a blocking error while type checking the SCCs . The `is_interface` flag indicates
51375189 whether this is a result for interface or implementation phase of type-checking.
51385190 """
51395191
51405192 def __init__ (
51415193 self ,
51425194 * ,
5143- scc_id : int ,
5195+ scc_ids : list [ int ] ,
51445196 is_interface : bool ,
51455197 result : dict [str , ModuleResult ] | None = None ,
51465198 blocker : CompileError | None = None ,
@@ -5149,26 +5201,26 @@ def __init__(
51495201 assert blocker is None
51505202 if blocker is not None :
51515203 assert result is None
5152- self .scc_id = scc_id
5204+ self .scc_ids = scc_ids
51535205 self .is_interface = is_interface
51545206 self .result = result
51555207 self .blocker = blocker
51565208
51575209 @classmethod
51585210 def read (cls , buf : ReadBuffer ) -> SccResponseMessage :
5159- scc_id = read_int (buf )
5211+ scc_ids = read_int_list (buf )
51605212 is_interface = read_bool (buf )
51615213 tag = read_tag (buf )
51625214 if tag == LITERAL_NONE :
51635215 return SccResponseMessage (
5164- scc_id = scc_id ,
5216+ scc_ids = scc_ids ,
51655217 is_interface = is_interface ,
51665218 blocker = CompileError (read_str_list (buf ), read_bool (buf ), read_str_opt (buf )),
51675219 )
51685220 else :
51695221 assert tag == DICT_STR_GEN
51705222 return SccResponseMessage (
5171- scc_id = scc_id ,
5223+ scc_ids = scc_ids ,
51725224 is_interface = is_interface ,
51735225 result = {
51745226 read_str_bare (buf ): ModuleResult .read (buf ) for _ in range (read_int_bare (buf ))
@@ -5177,7 +5229,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage:
51775229
51785230 def write (self , buf : WriteBuffer ) -> None :
51795231 write_tag (buf , SCC_RESPONSE_MESSAGE )
5180- write_int (buf , self .scc_id )
5232+ write_int_list (buf , self .scc_ids )
51815233 write_bool (buf , self .is_interface )
51825234 if self .result is None :
51835235 assert self .blocker is not None
0 commit comments