@@ -315,6 +315,13 @@ def share_model_across_processes(self) -> bool:
     https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
     return False
 
+  def model_copies(self) -> int:
+    """Returns the maximum number of model copies that should be loaded at one
+    time. This only impacts model handlers that are using
+    share_model_across_processes to share their model across processes instead
+    of being loaded per process."""
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     """Returns a boolean representing whether or not a model handler will
     override metrics reporting. If True, RunInference will not report any
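To adopt this hook, a concrete handler overrides `model_copies()` alongside `share_model_across_processes()`. A minimal sketch, assuming a toy model and a hypothetical handler name (neither is part of this change):

```python
from apache_beam.ml.inference.base import ModelHandler


class _FakeModel:
  """Stand-in for a real model object, for illustration only."""
  def predict(self, example):
    return example * 2


class FourCopyModelHandler(ModelHandler):
  """Hypothetical handler that asks RunInference to keep four shared copies."""
  def load_model(self):
    return _FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(e) for e in batch]

  def share_model_across_processes(self) -> bool:
    # Required for model_copies to take effect: copies live in a shared,
    # cross-process location rather than being loaded once per process.
    return True

  def model_copies(self) -> int:
    # Load at most four copies of the model at one time.
    return 4
```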
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
       return self._unkeyed.share_model_across_processes()
     return True
 
+  def model_copies(self) -> int:
+    if self._single_model:
+      return self._unkeyed.model_copies()
+    for mh in self._id_to_mh_map.values():
+      if mh.model_copies() != 1:
+        raise ValueError(
+            'KeyedModelHandler cannot map records to multiple '
+            'models if one or more of its ModelHandlers '
+            'require multiple model copies (set via '
+            'model_copies). To fix, verify that each '
+            'ModelHandler is not set to load multiple copies of '
+            'its model.')
+
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     if self._single_model:
       return self._unkeyed.override_metrics(metrics_namespace)
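The multi-model branch walks every per-key handler and rejects the configuration eagerly rather than oversubscribing memory at load time. A hedged sketch of both branches, reusing the hypothetical `FourCopyModelHandler` from the earlier sketch (`KeyModelMapping` is Beam's existing per-key mapping type):

```python
from apache_beam.ml.inference.base import KeyedModelHandler, KeyModelMapping

# Single-model form: model_copies() is delegated to the wrapped handler.
keyed = KeyedModelHandler(FourCopyModelHandler())
assert keyed.model_copies() == 4

# Multi-model form: every per-key handler must keep the default single copy,
# otherwise the ValueError added above is raised.
keyed_multi = KeyedModelHandler([
    KeyModelMapping(['en'], FourCopyModelHandler()),
    KeyModelMapping(['de'], FourCopyModelHandler()),
])
try:
  keyed_multi.model_copies()
except ValueError:
  pass  # rejected: per-key handlers may not request multiple copies
```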
@@ -902,6 +924,9 @@ def should_skip_batching(self) -> bool:
   def share_model_across_processes(self) -> bool:
     return self._unkeyed.share_model_across_processes()
 
+  def model_copies(self) -> int:
+    return self._unkeyed.model_copies()
+
 
 class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT],
                               ModelHandler[Sequence[ExampleT],
@@ -952,6 +977,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return True
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()
@@ -1012,6 +1043,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()
@@ -1071,6 +1108,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns() + [self._postprocess_fn]
@@ -1378,6 +1421,45 @@ def update(
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
+class _ModelRoutingStrategy():
+  """A class meant to sit in a shared location for mapping incoming batches to
+  different models. Currently only supports round-robin, but can be extended
+  to support other protocols if needed.
+  """
+  def __init__(self):
+    self._cur_index = 0
+
+  def next_model_index(self, num_models):
+    self._cur_index = (self._cur_index + 1) % num_models
+    return self._cur_index
+
+
+class _SharedModelWrapper():
+  """A router class to map incoming calls to the correct model.
+
+  This allows us to round robin calls to models sitting in different
+  processes so that we can more efficiently use resources (e.g. GPUs).
+  """
+  def __init__(self, models: List[Any], model_tag: str):
+    self.models = models
+    if len(models) > 1:
+      self.model_router = multi_process_shared.MultiProcessShared(
+          lambda: _ModelRoutingStrategy(),
+          tag=f'{model_tag}_counter',
+          always_proxy=True).acquire()
+
+  def next_model(self):
+    if len(self.models) == 1:
+      # Short circuit if there's no routing strategy needed in order to
+      # avoid the cross-process call
+      return self.models[0]
+
+    return self.models[self.model_router.next_model_index(len(self.models))]
+
+  def all_models(self):
+    return self.models
+
+
 class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def __init__(
       self,
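Note that `next_model_index` increments before returning, so with N models the first batch is routed to index 1 and the cycle then wraps through 0. A minimal local sketch of the round-robin behavior (the logic is re-declared here for illustration, outside of `MultiProcessShared`):

```python
class RoundRobin:
  """Mirror of _ModelRoutingStrategy's logic, for local demonstration."""
  def __init__(self):
    self._cur_index = 0

  def next_model_index(self, num_models):
    self._cur_index = (self._cur_index + 1) % num_models
    return self._cur_index


strategy = RoundRobin()
picks = [strategy.next_model_index(3) for _ in range(6)]
assert picks == [1, 2, 0, 1, 2, 0]  # starts at 1, then cycles through all
```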
@@ -1408,16 +1490,19 @@ def __init__(
   def _load_model(
       self,
       side_input_model_path: Optional[Union[str,
-                                            List[KeyModelPathMapping]]] = None):
+                                            List[KeyModelPathMapping]]] = None
+  ) -> _SharedModelWrapper:
     def load():
       """Function for constructing shared LoadedModel."""
       memory_before = _get_current_process_memory_in_bytes()
       start_time = _to_milliseconds(self._clock.time_ns())
       if isinstance(side_input_model_path, str):
         self._model_handler.update_model_path(side_input_model_path)
       else:
-        self._model_handler.update_model_paths(
-            self._model, side_input_model_path)
+        if self._model is not None:
+          models = self._model.all_models()
+          for m in models:
+            self._model_handler.update_model_paths(m, side_input_model_path)
       model = self._model_handler.load_model()
       end_time = _to_milliseconds(self._clock.time_ns())
       memory_after = _get_current_process_memory_in_bytes()
@@ -1434,19 +1519,27 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      models = []
+      for i in range(self._model_handler.model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _SharedModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _SharedModelWrapper([model], model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
     # of calling load(). For sanity check, call update_model_path again.
     if isinstance(side_input_model_path, str):
       self._model_handler.update_model_path(side_input_model_path)
     else:
-      self._model_handler.update_model_paths(self._model, side_input_model_path)
-    return model
+      if self._model is not None:
+        models = self._model.all_models()
+        for m in models:
+          self._model_handler.update_model_paths(m, side_input_model_path)
+    return model_wrapper
 
   def get_metrics_collector(self, prefix: str = ''):
     """
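Each copy is acquired under its own tag (`f'{model_tag}{i}'`), so every `MultiProcessShared` handle is an independent singleton and `load()` runs once per copy. An illustrative sketch of that tagging pattern, with a placeholder class standing in for the real `load()` closure:

```python
from apache_beam.utils import multi_process_shared


class _Copy:
  """Placeholder standing in for a loaded model."""
  def ping(self):
    return 'pong'


# Distinct tags ('demo_model0', 'demo_model1') yield distinct shared
# instances; re-acquiring an existing tag returns the same instance.
copies = [
    multi_process_shared.MultiProcessShared(
        _Copy, tag=f'demo_model{i}', always_proxy=True).acquire()
    for i in range(2)
]
```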
@@ -1476,8 +1569,9 @@ def update_model(
   def _run_inference(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     try:
+      model = self._model.next_model()
       result_generator = self._model_handler.run_inference(
-          batch, self._model, inference_args)
+          batch, model, inference_args)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()
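From the pipeline author's point of view nothing else changes: `_RunInferenceDoFn` now holds a `_SharedModelWrapper` and picks the next copy per batch transparently. A hypothetical end-to-end run using the sketch handler from earlier:

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

# Each incoming batch is routed round-robin across the four shared copies.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([1, 2, 3, 4])
      | RunInference(FourCopyModelHandler())
      | beam.Map(print))
```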