3636import uuid
3737from collections import OrderedDict
3838from collections import defaultdict
39+ from copy import deepcopy
40+ from dataclasses import dataclass
3941from typing import Any
4042from typing import Callable
4143from typing import Dict
@@ -122,6 +124,25 @@ def _to_microseconds(time_ns: int) -> int:
122124 return int (time_ns / _NANOSECOND_TO_MICROSECOND )
123125
124126
@dataclass(frozen=True)
class KeyModelPathMapping(Generic[KeyT]):
  """
  Maps one or more keys to a single updated model path.

  Used together with a KeyedModelHandler wrapping many model handlers to
  refresh a set of keys' models from a side input. For example, given
  `KeyModelPathMapping(keys=['key1', 'key2'], update_path='updated/path')`,
  every example keyed by `key1` or `key2` will have its corresponding model
  handler's update_model function called with 'updated/path'. For more
  information see the
  KeyedModelHandler documentation
  https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
  documentation and the website section on model updates
  https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
  """
  # Keys whose model handlers should be refreshed with update_path.
  keys: List[KeyT]
  # The new model path to apply to every key listed above.
  update_path: str
145+
125146class ModelHandler (Generic [ExampleT , PredictionT , ModelT ]):
126147 """Has the ability to load and apply an ML model."""
127148 def __init__ (self ):
@@ -191,7 +212,28 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
191212 'framework does not expect extra arguments on inferences.' )
192213
193214 def update_model_path (self , model_path : Optional [str ] = None ):
194- """Update the model paths produced by side inputs."""
215+ """
216+ Update the model path produced by side inputs. update_model_path should be
217+ used when a ModelHandler represents a single model, not multiple models.
218+ This will be true in most cases. For more information see the website
219+ section on model updates
220+ https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
221+ """
222+ pass
223+
224+ def update_model_paths (
225+ self ,
226+ model : ModelT ,
227+ model_paths : Optional [Union [str , List [KeyModelPathMapping ]]] = None ):
228+ """
229+ Update the model paths produced by side inputs. update_model_paths should
230+ be used when updating multiple models at once (e.g. when using a
231+ KeyedModelHandler that holds multiple models). For more information see
232+ the KeyedModelHandler documentation
233+ https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
234+ documentation and the website section on model updates
235+ https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
236+ """
195237 pass
196238
197239 def get_preprocess_fns (self ) -> Iterable [Callable [[Any ], Any ]]:
@@ -264,10 +306,17 @@ def __init__(
264306 allow unlimited models.
265307 """
266308 self ._max_models = max_models
309+ # Map keys to model handlers
267310 self ._mh_map : Dict [str , ModelHandler ] = mh_map
268- self ._proxy_map : Dict [str , str ] = {}
269- self ._tag_map : Dict [
270- str , multi_process_shared .MultiProcessShared ] = OrderedDict ()
311+ # Map keys to the last updated model path for that key
312+ self ._key_to_last_update : Dict [str , str ] = defaultdict (str )
313+ # Map key for a model to a unique tag that will persist for the life of
314+ # that model in memory. A new tag will be generated if a model is swapped
315+ # out of memory and reloaded.
316+ self ._tag_map : Dict [str , str ] = OrderedDict ()
317+ # Map a tag to a multiprocessshared model object for that tag. Each entry
318+ # of this map should last as long as the corresponding entry in _tag_map.
319+ self ._proxy_map : Dict [str , multi_process_shared .MultiProcessShared ] = {}
271320
272321 def load (self , key : str ) -> str :
273322 """
@@ -294,6 +343,7 @@ def load(self, key: str) -> str:
294343 tag_to_remove = self ._tag_map .popitem (last = False )[1 ]
295344 shared_handle , model_to_remove = self ._proxy_map [tag_to_remove ]
296345 shared_handle .release (model_to_remove )
346+ del self ._proxy_map [tag_to_remove ]
297347
298348 # Load the new model
299349 shared_handle = multi_process_shared .MultiProcessShared (
@@ -316,6 +366,32 @@ def increment_max_models(self, increment: int):
316366 " models mode)." )
317367 self ._max_models += increment
318368
369+ def update_model_handler (self , key : str , model_path : str , previous_key : str ):
370+ """
371+ Updates the model path of this model handler and removes it from memory so
372+ that it can be reloaded with the updated path. No-ops if no model update
373+ needs to be applied.
374+ Args:
375+ key: the key associated with the model we'd like to update.
376+ model_path: the new path to the model we'd like to load.
377+ previous_key: the key that is associated with the old version of this
378+ model. This will often be the same as the current key, but sometimes
379+ we will want to keep both the old and new models to serve different
380+ cohorts. In that case, the keys should be different.
381+ """
382+ if self ._key_to_last_update [key ] == model_path :
383+ return
384+ self ._key_to_last_update [key ] = model_path
385+ if key not in self ._mh_map :
386+ self ._mh_map [key ] = deepcopy (self ._mh_map [previous_key ])
387+ self ._mh_map [key ].update_model_path (model_path )
388+ if key in self ._tag_map :
389+ tag_to_remove = self ._tag_map [key ]
390+ shared_handle , model_to_remove = self ._proxy_map [tag_to_remove ]
391+ shared_handle .release (model_to_remove )
392+ del self ._tag_map [key ]
393+ del self ._proxy_map [tag_to_remove ]
394+
319395
320396# Use a dataclass instead of named tuple because NamedTuples and generics don't
321397# mix well across the board for all versions:
@@ -359,6 +435,31 @@ def __init__(
359435 at the same time; be careful not to load too many large models or your
360436 pipeline may cause Out of Memory exceptions.
361437
438+ KeyedModelHandlers support Automatic Model Refresh to update your model
439+ to a newer version without stopping your streaming pipeline. For an
440+ overview of this feature, see
441+ https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
442+
443+
444+ To use this feature with a KeyedModelHandler that has many models per key,
445+ you can pass in a list of KeyModelPathMapping objects to define your new
446+ model paths. For example, passing in the side input of
447+
448+ [KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'),
449+ KeyModelPathMapping(keys=['k3'], update_path='update/path/2')]
450+
451+ will update the model corresponding to keys 'k1' and 'k2' with path
452+ 'update/path/1' and the model corresponding to 'k3' with 'update/path/2'.
453+ In order to do a side input update: (1) all restrictions mentioned in
454+ https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
455+ must be met, (2) all update_paths must be non-empty, even if they are not
456+ being updated from their original values, and (3) the set of keys
457+ originally defined cannot change. This means that if originally you have
458+ defined model handlers for 'key1', 'key2', and 'key3', all 3 of those keys
459+ must appear in your list of KeyModelPathMappings exactly once. No
460+ additional keys can be added.
461+
462+
362463 Args:
363464 unkeyed: Either (a) an implementation of ModelHandler that does not
364465 require keys or (b) a list of KeyMhMappings mapping lists of keys to
@@ -512,6 +613,75 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
512613 for mh in self ._id_to_mh_map .values ():
513614 mh .validate_inference_args (inference_args )
514615
616+ def update_model_paths (
617+ self ,
618+ model : Union [ModelT , _ModelManager ],
619+ model_paths : List [KeyModelPathMapping [KeyT ]] = None ):
620+ # When there are many models, the keyed model handler is responsible for
621+ # reorganizing the model handlers into cohorts and telling the model
622+ # manager to update every cohort's associated model handler. The model
623+ # manager is responsible for performing the updates and tracking which
624+ # updates have already been applied.
625+ if model_paths is None or len (model_paths ) == 0 or model is None :
626+ return
627+ if self ._single_model :
628+ raise RuntimeError (
629+ 'Invalid model update: sent many model paths to '
630+ 'update, but KeyedModelHandler is wrapping a single '
631+ 'model.' )
632+ # Map cohort ids to a dictionary mapping new model paths to the keys that
633+ # were originally in that cohort. We will use this to construct our new
634+ # cohorts.
635+ # cohort_path_mapping will be structured as follows:
636+ # {
637+ # original_cohort_id: {
638+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
639+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
640+ # }
641+ # }
642+ cohort_path_mapping : Dict [KeyT , Dict [str , List [KeyT ]]] = {}
643+ seen_keys = set ()
644+ for mp in model_paths :
645+ keys = mp .keys
646+ update_path = mp .update_path
647+ if len (update_path ) == 0 :
648+ raise ValueError (f'Invalid model update, path for { keys } is empty' )
649+ for key in keys :
650+ if key in seen_keys :
651+ raise ValueError (
652+ f'Invalid model update: { key } appears in multiple '
653+ 'update lists. A single model update must provide exactly one '
654+ 'updated path per key.' )
655+ seen_keys .add (key )
656+ if key not in self ._key_to_id_map :
657+ raise ValueError (
658+ f'Invalid model update: { key } appears in '
659+ 'update, but not in the original configuration.' )
660+ cohort_id = self ._key_to_id_map [key ]
661+ if cohort_id not in cohort_path_mapping :
662+ cohort_path_mapping [cohort_id ] = defaultdict (list )
663+ cohort_path_mapping [cohort_id ][update_path ].append (key )
664+ for key in self ._key_to_id_map :
665+ if key not in seen_keys :
666+ raise ValueError (
667+ f'Invalid model update: { key } appears in the '
668+ 'original configuration, but not the update.' )
669+
670+ # We now have our new set of cohorts. For each one, update our local model
671+ # handler configuration and send the results to the ModelManager
672+ for old_cohort_id , path_key_mapping in cohort_path_mapping .items ():
673+ for updated_path , keys in path_key_mapping .items ():
674+ cohort_id = old_cohort_id
675+ if old_cohort_id not in keys :
676+ # Create new cohort
677+ cohort_id = keys [0 ]
678+ for key in keys :
679+ self ._key_to_id_map [key ] = cohort_id
680+ mh = self ._id_to_mh_map [old_cohort_id ]
681+ self ._id_to_mh_map [cohort_id ] = deepcopy (mh )
682+ self ._id_to_mh_map [cohort_id ].update_model_path (updated_path )
683+ model .update_model_handler (cohort_id , updated_path , old_cohort_id )
684+
515685 def update_model_path (self , model_path : Optional [str ] = None ):
516686 if self ._single_model :
517687 return self ._unkeyed .update_model_path (model_path = model_path )
@@ -1046,12 +1216,19 @@ def __init__(
10461216 self ._side_input_path = None
10471217 self ._model_tag = model_tag
10481218
1049- def _load_model (self , side_input_model_path : Optional [str ] = None ):
1219+ def _load_model (
1220+ self ,
1221+ side_input_model_path : Optional [Union [str ,
1222+ List [KeyModelPathMapping ]]] = None ):
10501223 def load ():
10511224 """Function for constructing shared LoadedModel."""
10521225 memory_before = _get_current_process_memory_in_bytes ()
10531226 start_time = _to_milliseconds (self ._clock .time_ns ())
1054- self ._model_handler .update_model_path (side_input_model_path )
1227+ if isinstance (side_input_model_path , str ):
1228+ self ._model_handler .update_model_path (side_input_model_path )
1229+ else :
1230+ self ._model_handler .update_model_paths (
1231+ self ._model , side_input_model_path )
10551232 model = self ._model_handler .load_model ()
10561233 end_time = _to_milliseconds (self ._clock .time_ns ())
10571234 memory_after = _get_current_process_memory_in_bytes ()
@@ -1063,18 +1240,22 @@ def load():
10631240
10641241 # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
10651242 # model.
1243+ model_tag = self ._model_tag
1244+ if isinstance (side_input_model_path , str ) and side_input_model_path != '' :
1245+ model_tag = side_input_model_path
10661246 if self ._model_handler .share_model_across_processes ():
10671247 model = multi_process_shared .MultiProcessShared (
1068- load , tag = side_input_model_path or self ._model_tag ,
1069- always_proxy = True ).acquire ()
1248+ load , tag = model_tag , always_proxy = True ).acquire ()
10701249 else :
1071- model = self ._shared_model_handle .acquire (
1072- load , tag = side_input_model_path or self ._model_tag )
1250+ model = self ._shared_model_handle .acquire (load , tag = model_tag )
10731251 # since shared_model_handle is shared across threads, the model path
10741252 # might not get updated in the model handler
10751253 # because we directly get cached weak ref model from shared cache, instead
10761254 # of calling load(). For sanity check, call update_model_path again.
1077- self ._model_handler .update_model_path (side_input_model_path )
1255+ if isinstance (side_input_model_path , str ):
1256+ self ._model_handler .update_model_path (side_input_model_path )
1257+ else :
1258+ self ._model_handler .update_model_paths (self ._model , side_input_model_path )
10781259 return model
10791260
10801261 def get_metrics_collector (self , prefix : str = '' ):
@@ -1094,7 +1275,10 @@ def setup(self):
10941275 if not self ._enable_side_input_loading :
10951276 self ._model = self ._load_model ()
10961277
1097- def update_model (self , side_input_model_path : Optional [str ] = None ):
1278+ def update_model (
1279+ self ,
1280+ side_input_model_path : Optional [Union [str ,
1281+ List [KeyModelPathMapping ]]] = None ):
10981282 self ._model = self ._load_model (side_input_model_path = side_input_model_path )
10991283
11001284 def _run_inference (self , batch , inference_args ):
@@ -1116,25 +1300,36 @@ def _run_inference(self, batch, inference_args):
11161300 return predictions
11171301
11181302 def process (
1119- self , batch , inference_args , si_model_metadata : Optional [ModelMetadata ]):
1303+ self ,
1304+ batch ,
1305+ inference_args ,
1306+ si_model_metadata : Optional [Union [ModelMetadata ,
1307+ List [ModelMetadata ],
1308+ List [KeyModelPathMapping ]]]):
11201309 """
11211310 When side input is enabled:
11221311 The method checks if the side input model has been updated, and if so,
11231312 updates the model and runs inference on the batch of data. If the
11241313 side input is empty or the model has not been updated, the method
11251314 simply runs inference on the batch of data.
11261315 """
1127- if si_model_metadata :
1128- if isinstance (si_model_metadata , beam .pvalue .EmptySideInput ):
1129- self .update_model (side_input_model_path = None )
1316+ if not si_model_metadata :
1317+ return self ._run_inference (batch , inference_args )
1318+
1319+ if isinstance (si_model_metadata , beam .pvalue .EmptySideInput ):
1320+ self .update_model (side_input_model_path = None )
1321+ elif isinstance (si_model_metadata , List ) and hasattr (si_model_metadata [0 ],
1322+ 'keys' ):
1323+ # TODO(https://github.com/apache/beam/issues/27628): Update metrics here
1324+ self .update_model (si_model_metadata )
1325+ elif self ._side_input_path != si_model_metadata .model_id :
1326+ self ._side_input_path = si_model_metadata .model_id
1327+ self ._metrics_collector = self .get_metrics_collector (
1328+ prefix = si_model_metadata .model_name )
1329+ with threading .Lock ():
1330+ self .update_model (si_model_metadata .model_id )
11301331 return self ._run_inference (batch , inference_args )
1131- elif self ._side_input_path != si_model_metadata .model_id :
1132- self ._side_input_path = si_model_metadata .model_id
1133- self ._metrics_collector = self .get_metrics_collector (
1134- prefix = si_model_metadata .model_name )
1135- with threading .Lock ():
1136- self .update_model (si_model_metadata .model_id )
1137- return self ._run_inference (batch , inference_args )
1332+
11381333 return self ._run_inference (batch , inference_args )
11391334
11401335 def finish_bundle (self ):
0 commit comments