Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3b95d14

Browse files
Per key model updates (#28161)
* WIP: per key model updates * Still wip, but semi-complete * Things are working, may need cleanup * Fix lint issues * Doc improvements * Apply suggestions from code review Co-authored-by: Anand Inguva <[email protected]> * Feedback * fix locking --------- Co-authored-by: Anand Inguva <[email protected]>
1 parent 2a87135 commit 3b95d14

2 files changed

Lines changed: 454 additions & 25 deletions

File tree

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 218 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import uuid
3737
from collections import OrderedDict
3838
from collections import defaultdict
39+
from copy import deepcopy
40+
from dataclasses import dataclass
3941
from typing import Any
4042
from typing import Callable
4143
from typing import Dict
@@ -122,6 +124,25 @@ def _to_microseconds(time_ns: int) -> int:
122124
return int(time_ns / _NANOSECOND_TO_MICROSECOND)
123125

124126

127+
@dataclass(frozen=True)
128+
class KeyModelPathMapping(Generic[KeyT]):
129+
"""
130+
Dataclass for mapping 1 or more keys to 1 model path. This is used in
131+
conjunction with a KeyedModelHandler with many model handlers to update
132+
a set of keys' model handlers with the new path. Given
133+
`KeyModelPathMapping(keys: ['key1', 'key2'], update_path: 'updated/path')`,
134+
all examples with keys `key1` or `key2` will have their corresponding model
135+
handler's update_model function called with 'updated/path'. For more
136+
information see the
137+
KeyedModelHandler documentation
138+
https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
139+
and the website section on model updates
140+
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
141+
"""
142+
keys: List[KeyT]
143+
update_path: str
144+
145+
125146
class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
126147
"""Has the ability to load and apply an ML model."""
127148
def __init__(self):
@@ -191,7 +212,28 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
191212
'framework does not expect extra arguments on inferences.')
192213

193214
def update_model_path(self, model_path: Optional[str] = None):
194-
"""Update the model paths produced by side inputs."""
215+
"""
216+
Update the model path produced by side inputs. update_model_path should be
217+
used when a ModelHandler represents a single model, not multiple models.
218+
This will be true in most cases. For more information see the website
219+
section on model updates
220+
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
221+
"""
222+
pass
223+
224+
def update_model_paths(
225+
self,
226+
model: ModelT,
227+
model_paths: Optional[Union[str, List[KeyModelPathMapping]]] = None):
228+
"""
229+
Update the model paths produced by side inputs. update_model_paths should
230+
be used when updating multiple models at once (e.g. when using a
231+
KeyedModelHandler that holds multiple models). For more information see
232+
the KeyedModelHandler documentation
233+
https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler
234+
and the website section on model updates
235+
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
236+
"""
195237
pass
196238

197239
def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
@@ -264,10 +306,17 @@ def __init__(
264306
allow unlimited models.
265307
"""
266308
self._max_models = max_models
309+
# Map keys to model handlers
267310
self._mh_map: Dict[str, ModelHandler] = mh_map
268-
self._proxy_map: Dict[str, str] = {}
269-
self._tag_map: Dict[
270-
str, multi_process_shared.MultiProcessShared] = OrderedDict()
311+
# Map keys to the last updated model path for that key
312+
self._key_to_last_update: Dict[str, str] = defaultdict(str)
313+
# Map key for a model to a unique tag that will persist for the life of
314+
# that model in memory. A new tag will be generated if a model is swapped
315+
# out of memory and reloaded.
316+
self._tag_map: Dict[str, str] = OrderedDict()
317+
# Map a tag to a MultiProcessShared model object for that tag. Each entry
318+
# of this map should last as long as the corresponding entry in _tag_map.
319+
self._proxy_map: Dict[str, multi_process_shared.MultiProcessShared] = {}
271320

272321
def load(self, key: str) -> str:
273322
"""
@@ -294,6 +343,7 @@ def load(self, key: str) -> str:
294343
tag_to_remove = self._tag_map.popitem(last=False)[1]
295344
shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
296345
shared_handle.release(model_to_remove)
346+
del self._proxy_map[tag_to_remove]
297347

298348
# Load the new model
299349
shared_handle = multi_process_shared.MultiProcessShared(
@@ -316,6 +366,32 @@ def increment_max_models(self, increment: int):
316366
" models mode).")
317367
self._max_models += increment
318368

369+
def update_model_handler(self, key: str, model_path: str, previous_key: str):
370+
"""
371+
Updates the model path of this model handler and removes it from memory so
372+
that it can be reloaded with the updated path. No-ops if no model update
373+
needs to be applied.
374+
Args:
375+
key: the key associated with the model we'd like to update.
376+
model_path: the new path to the model we'd like to load.
377+
previous_key: the key that is associated with the old version of this
378+
model. This will often be the same as the current key, but sometimes
379+
we will want to keep both the old and new models to serve different
380+
cohorts. In that case, the keys should be different.
381+
"""
382+
if self._key_to_last_update[key] == model_path:
383+
return
384+
self._key_to_last_update[key] = model_path
385+
if key not in self._mh_map:
386+
self._mh_map[key] = deepcopy(self._mh_map[previous_key])
387+
self._mh_map[key].update_model_path(model_path)
388+
if key in self._tag_map:
389+
tag_to_remove = self._tag_map[key]
390+
shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
391+
shared_handle.release(model_to_remove)
392+
del self._tag_map[key]
393+
del self._proxy_map[tag_to_remove]
394+
319395

320396
# Use a dataclass instead of named tuple because NamedTuples and generics don't
321397
# mix well across the board for all versions:
@@ -359,6 +435,31 @@ def __init__(
359435
at the same time; be careful not to load too many large models or your
360436
pipeline may cause Out of Memory exceptions.
361437
438+
KeyedModelHandlers support Automatic Model Refresh to update your model
439+
to a newer version without stopping your streaming pipeline. For an
440+
overview of this feature, see
441+
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
442+
443+
444+
To use this feature with a KeyedModelHandler that has many models per key,
445+
you can pass in a list of KeyModelPathMapping objects to define your new
446+
model paths. For example, passing in the side input of
447+
448+
[KeyModelPathMapping(keys=['k1', 'k2'], update_path='update/path/1'),
449+
KeyModelPathMapping(keys=['k3'], update_path='update/path/2')]
450+
451+
will update the model corresponding to keys 'k1' and 'k2' with path
452+
'update/path/1' and the model corresponding to 'k3' with 'update/path/2'.
453+
In order to do a side input update: (1) all restrictions mentioned in
454+
https://beam.apache.org/documentation/sdks/python-machine-learning/#automatic-model-refresh
455+
must be met, (2) all update_paths must be non-empty, even if they are not
456+
being updated from their original values, and (3) the set of keys
457+
originally defined cannot change. This means that if originally you have
458+
defined model handlers for 'key1', 'key2', and 'key3', all 3 of those keys
459+
must appear in your list of KeyModelPathMappings exactly once. No
460+
additional keys can be added.
461+
462+
362463
Args:
363464
unkeyed: Either (a) an implementation of ModelHandler that does not
364465
require keys or (b) a list of KeyMhMappings mapping lists of keys to
@@ -512,6 +613,75 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
512613
for mh in self._id_to_mh_map.values():
513614
mh.validate_inference_args(inference_args)
514615

616+
def update_model_paths(
617+
self,
618+
model: Union[ModelT, _ModelManager],
619+
model_paths: List[KeyModelPathMapping[KeyT]] = None):
620+
# When there are many models, the keyed model handler is responsible for
621+
# reorganizing the model handlers into cohorts and telling the model
622+
# manager to update every cohort's associated model handler. The model
623+
# manager is responsible for performing the updates and tracking which
624+
# updates have already been applied.
625+
if model_paths is None or len(model_paths) == 0 or model is None:
626+
return
627+
if self._single_model:
628+
raise RuntimeError(
629+
'Invalid model update: sent many model paths to '
630+
'update, but KeyedModelHandler is wrapping a single '
631+
'model.')
632+
# Map cohort ids to a dictionary mapping new model paths to the keys that
633+
# were originally in that cohort. We will use this to construct our new
634+
# cohorts.
635+
# cohort_path_mapping will be structured as follows:
636+
# {
637+
# original_cohort_id: {
638+
# 'update/path/1': ['key1FromOriginalCohort', 'key2FromOriginalCohort'],
639+
# 'update/path/2': ['key3FromOriginalCohort', 'key4FromOriginalCohort'],
640+
# }
641+
# }
642+
cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
643+
seen_keys = set()
644+
for mp in model_paths:
645+
keys = mp.keys
646+
update_path = mp.update_path
647+
if len(update_path) == 0:
648+
raise ValueError(f'Invalid model update, path for {keys} is empty')
649+
for key in keys:
650+
if key in seen_keys:
651+
raise ValueError(
652+
f'Invalid model update: {key} appears in multiple '
653+
'update lists. A single model update must provide exactly one '
654+
'updated path per key.')
655+
seen_keys.add(key)
656+
if key not in self._key_to_id_map:
657+
raise ValueError(
658+
f'Invalid model update: {key} appears in '
659+
'update, but not in the original configuration.')
660+
cohort_id = self._key_to_id_map[key]
661+
if cohort_id not in cohort_path_mapping:
662+
cohort_path_mapping[cohort_id] = defaultdict(list)
663+
cohort_path_mapping[cohort_id][update_path].append(key)
664+
for key in self._key_to_id_map:
665+
if key not in seen_keys:
666+
raise ValueError(
667+
f'Invalid model update: {key} appears in the '
668+
'original configuration, but not the update.')
669+
670+
# We now have our new set of cohorts. For each one, update our local model
671+
# handler configuration and send the results to the ModelManager
672+
for old_cohort_id, path_key_mapping in cohort_path_mapping.items():
673+
for updated_path, keys in path_key_mapping.items():
674+
cohort_id = old_cohort_id
675+
if old_cohort_id not in keys:
676+
# Create new cohort
677+
cohort_id = keys[0]
678+
for key in keys:
679+
self._key_to_id_map[key] = cohort_id
680+
mh = self._id_to_mh_map[old_cohort_id]
681+
self._id_to_mh_map[cohort_id] = deepcopy(mh)
682+
self._id_to_mh_map[cohort_id].update_model_path(updated_path)
683+
model.update_model_handler(cohort_id, updated_path, old_cohort_id)
684+
515685
def update_model_path(self, model_path: Optional[str] = None):
516686
if self._single_model:
517687
return self._unkeyed.update_model_path(model_path=model_path)
@@ -1046,12 +1216,19 @@ def __init__(
10461216
self._side_input_path = None
10471217
self._model_tag = model_tag
10481218

1049-
def _load_model(self, side_input_model_path: Optional[str] = None):
1219+
def _load_model(
1220+
self,
1221+
side_input_model_path: Optional[Union[str,
1222+
List[KeyModelPathMapping]]] = None):
10501223
def load():
10511224
"""Function for constructing shared LoadedModel."""
10521225
memory_before = _get_current_process_memory_in_bytes()
10531226
start_time = _to_milliseconds(self._clock.time_ns())
1054-
self._model_handler.update_model_path(side_input_model_path)
1227+
if isinstance(side_input_model_path, str):
1228+
self._model_handler.update_model_path(side_input_model_path)
1229+
else:
1230+
self._model_handler.update_model_paths(
1231+
self._model, side_input_model_path)
10551232
model = self._model_handler.load_model()
10561233
end_time = _to_milliseconds(self._clock.time_ns())
10571234
memory_after = _get_current_process_memory_in_bytes()
@@ -1063,18 +1240,22 @@ def load():
10631240

10641241
# TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
10651242
# model.
1243+
model_tag = self._model_tag
1244+
if isinstance(side_input_model_path, str) and side_input_model_path != '':
1245+
model_tag = side_input_model_path
10661246
if self._model_handler.share_model_across_processes():
10671247
model = multi_process_shared.MultiProcessShared(
1068-
load, tag=side_input_model_path or self._model_tag,
1069-
always_proxy=True).acquire()
1248+
load, tag=model_tag, always_proxy=True).acquire()
10701249
else:
1071-
model = self._shared_model_handle.acquire(
1072-
load, tag=side_input_model_path or self._model_tag)
1250+
model = self._shared_model_handle.acquire(load, tag=model_tag)
10731251
# since shared_model_handle is shared across threads, the model path
10741252
# might not get updated in the model handler
10751253
# because we directly get cached weak ref model from shared cache, instead
10761254
# of calling load(). For sanity check, call update_model_path again.
1077-
self._model_handler.update_model_path(side_input_model_path)
1255+
if isinstance(side_input_model_path, str):
1256+
self._model_handler.update_model_path(side_input_model_path)
1257+
else:
1258+
self._model_handler.update_model_paths(self._model, side_input_model_path)
10781259
return model
10791260

10801261
def get_metrics_collector(self, prefix: str = ''):
@@ -1094,7 +1275,10 @@ def setup(self):
10941275
if not self._enable_side_input_loading:
10951276
self._model = self._load_model()
10961277

1097-
def update_model(self, side_input_model_path: Optional[str] = None):
1278+
def update_model(
1279+
self,
1280+
side_input_model_path: Optional[Union[str,
1281+
List[KeyModelPathMapping]]] = None):
10981282
self._model = self._load_model(side_input_model_path=side_input_model_path)
10991283

11001284
def _run_inference(self, batch, inference_args):
@@ -1116,25 +1300,36 @@ def _run_inference(self, batch, inference_args):
11161300
return predictions
11171301

11181302
def process(
1119-
self, batch, inference_args, si_model_metadata: Optional[ModelMetadata]):
1303+
self,
1304+
batch,
1305+
inference_args,
1306+
si_model_metadata: Optional[Union[ModelMetadata,
1307+
List[ModelMetadata],
1308+
List[KeyModelPathMapping]]]):
11201309
"""
11211310
When side input is enabled:
11221311
The method checks if the side input model has been updated, and if so,
11231312
updates the model and runs inference on the batch of data. If the
11241313
side input is empty or the model has not been updated, the method
11251314
simply runs inference on the batch of data.
11261315
"""
1127-
if si_model_metadata:
1128-
if isinstance(si_model_metadata, beam.pvalue.EmptySideInput):
1129-
self.update_model(side_input_model_path=None)
1316+
if not si_model_metadata:
1317+
return self._run_inference(batch, inference_args)
1318+
1319+
if isinstance(si_model_metadata, beam.pvalue.EmptySideInput):
1320+
self.update_model(side_input_model_path=None)
1321+
elif isinstance(si_model_metadata, List) and hasattr(si_model_metadata[0],
1322+
'keys'):
1323+
# TODO(https://github.com/apache/beam/issues/27628): Update metrics here
1324+
self.update_model(si_model_metadata)
1325+
elif self._side_input_path != si_model_metadata.model_id:
1326+
self._side_input_path = si_model_metadata.model_id
1327+
self._metrics_collector = self.get_metrics_collector(
1328+
prefix=si_model_metadata.model_name)
1329+
with threading.Lock():
1330+
self.update_model(si_model_metadata.model_id)
11301331
return self._run_inference(batch, inference_args)
1131-
elif self._side_input_path != si_model_metadata.model_id:
1132-
self._side_input_path = si_model_metadata.model_id
1133-
self._metrics_collector = self.get_metrics_collector(
1134-
prefix=si_model_metadata.model_name)
1135-
with threading.Lock():
1136-
self.update_model(si_model_metadata.model_id)
1137-
return self._run_inference(batch, inference_args)
1332+
11381333
return self._run_inference(batch, inference_args)
11391334

11401335
def finish_bundle(self):

0 commit comments

Comments
 (0)