Commit 3c8a881

Add ability to load multiple copies of a model across processes (#31052)
* Add ability to load multiple copies of a model across processes
* push changes I had locally not remotely
* Lint
* naming + lint
* Changes from feedback
1 parent 3d3669e commit 3c8a881

8 files changed

Lines changed: 283 additions & 30 deletions
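
For context, here is a minimal sketch of how a pipeline author could opt in to this behavior after this change. Only share_model_across_processes(), model_copies(), and base.RunInference come from the API touched by this diff; the DoublingModelHandler and its stand-in model are hypothetical, for illustration only.

# Hypothetical sketch (not part of this commit): a handler that asks
# RunInference to keep 4 copies of its model in the multi-process shared
# model server and round-robin batches across them, instead of loading
# one copy per worker process.
from typing import Iterable, Optional, Sequence

import apache_beam as beam
from apache_beam.ml.inference import base


class _DoublingModel:
  """Stand-in for an expensive-to-load model; doubles its input."""
  def predict(self, example: int) -> int:
    return example * 2


class DoublingModelHandler(base.ModelHandler[int, int, _DoublingModel]):
  def load_model(self) -> _DoublingModel:
    return _DoublingModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: _DoublingModel,
      inference_args: Optional[dict] = None) -> Iterable[int]:
    return [model.predict(example) for example in batch]

  def share_model_across_processes(self) -> bool:
    # model_copies only takes effect when the model is shared across processes.
    return True

  def model_copies(self) -> int:
    # Ask RunInference to load and route across 4 copies of the model.
    return 4


if __name__ == '__main__':
  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create([1, 2, 3, 4])
        | base.RunInference(DoublingModelHandler())
        | beam.Map(print))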


sdks/python/apache_beam/ml/inference/base.py

Lines changed: 102 additions & 8 deletions
@@ -315,6 +315,13 @@ def share_model_across_processes(self) -> bool:
     https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
     return False
 
+  def model_copies(self) -> int:
+    """Returns the maximum number of model copies that should be loaded at one
+    time. This only impacts model handlers that are using
+    share_model_across_processes to share their model across processes instead
+    of being loaded per process."""
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     """Returns a boolean representing whether or not a model handler will
     override metrics reporting. If True, RunInference will not report any
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
       return self._unkeyed.share_model_across_processes()
     return True
 
+  def model_copies(self) -> int:
+    if self._single_model:
+      return self._unkeyed.model_copies()
+    for mh in self._id_to_mh_map.values():
+      if mh.model_copies() != 1:
+        raise ValueError(
+            'KeyedModelHandler cannot map records to multiple '
+            'models if one or more of its ModelHandlers '
+            'require multiple model copies (set via '
+            'model_copies). To fix, verify that each '
+            'ModelHandler is not set to load multiple copies of '
+            'its model.')
+
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     if self._single_model:
       return self._unkeyed.override_metrics(metrics_namespace)
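
To make the new KeyedModelHandler restriction concrete, a hedged sketch of a configuration this check rejects. The handler classes below are hypothetical; only the KeyedModelHandler / KeyModelMapping composition and the model_copies() override are taken from the real API.

# Hypothetical sketch: two per-key handlers under one KeyedModelHandler.
# Because FourCopyHandler requests multiple model copies, the model_copies()
# check added in this diff raises the ValueError above.
from apache_beam.ml.inference import base


class OneCopyHandler(base.ModelHandler[int, int, object]):
  def load_model(self):
    return object()

  def run_inference(self, batch, model, inference_args=None):
    return batch


class FourCopyHandler(OneCopyHandler):
  def share_model_across_processes(self) -> bool:
    return True

  def model_copies(self) -> int:
    return 4


keyed = base.KeyedModelHandler([
    base.KeyModelMapping(['en'], OneCopyHandler()),
    base.KeyModelMapping(['fr'], FourCopyHandler()),
])
keyed.model_copies()  # raises ValueError: keys map to multiple models, one of which wants 4 copies
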
@@ -902,6 +924,9 @@ def should_skip_batching(self) -> bool:
   def share_model_across_processes(self) -> bool:
     return self._unkeyed.share_model_across_processes()
 
+  def model_copies(self) -> int:
+    return self._unkeyed.model_copies()
+
 
 class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT],
                               ModelHandler[Sequence[ExampleT],
@@ -952,6 +977,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return True
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()

@@ -1012,6 +1043,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()

@@ -1071,6 +1108,12 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns() + [self._postprocess_fn]

@@ -1378,6 +1421,45 @@ def update(
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
+class _ModelRoutingStrategy():
+  """A class meant to sit in a shared location for mapping incoming batches to
+  different models. Currently only supports round-robin, but can be extended
+  to support other protocols if needed.
+  """
+  def __init__(self):
+    self._cur_index = 0
+
+  def next_model_index(self, num_models):
+    self._cur_index = (self._cur_index + 1) % num_models
+    return self._cur_index
+
+
+class _SharedModelWrapper():
+  """A router class to map incoming calls to the correct model.
+
+  This allows us to round robin calls to models sitting in different
+  processes so that we can more efficiently use resources (e.g. GPUs).
+  """
+  def __init__(self, models: List[Any], model_tag: str):
+    self.models = models
+    if len(models) > 1:
+      self.model_router = multi_process_shared.MultiProcessShared(
+          lambda: _ModelRoutingStrategy(),
+          tag=f'{model_tag}_counter',
+          always_proxy=True).acquire()
+
+  def next_model(self):
+    if len(self.models) == 1:
+      # Short circuit if there's no routing strategy needed in order to
+      # avoid the cross-process call
+      return self.models[0]
+
+    return self.models[self.model_router.next_model_index(len(self.models))]
+
+  def all_models(self):
+    return self.models
+
+
 class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def __init__(
       self,
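
To illustrate the routing behavior, here is a standalone re-implementation of the round-robin arithmetic used by _ModelRoutingStrategy (plain Python, no Beam dependencies, for illustration only). The real class additionally lives behind MultiProcessShared so that every process on the worker advances a single shared counter.

# Minimal sketch of the same round-robin index math.
class RoundRobin:
  def __init__(self):
    self._cur_index = 0

  def next_model_index(self, num_models: int) -> int:
    # Advance first, then return, so the first call yields index 1; the
    # copies are still visited evenly over time.
    self._cur_index = (self._cur_index + 1) % num_models
    return self._cur_index


router = RoundRobin()
print([router.next_model_index(4) for _ in range(8)])
# [1, 2, 3, 0, 1, 2, 3, 0]
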
@@ -1408,16 +1490,19 @@ def __init__(
   def _load_model(
       self,
       side_input_model_path: Optional[Union[str,
-                                            List[KeyModelPathMapping]]] = None):
+                                            List[KeyModelPathMapping]]] = None
+  ) -> _SharedModelWrapper:
     def load():
       """Function for constructing shared LoadedModel."""
       memory_before = _get_current_process_memory_in_bytes()
       start_time = _to_milliseconds(self._clock.time_ns())
       if isinstance(side_input_model_path, str):
         self._model_handler.update_model_path(side_input_model_path)
       else:
-        self._model_handler.update_model_paths(
-            self._model, side_input_model_path)
+        if self._model is not None:
+          models = self._model.all_models()
+          for m in models:
+            self._model_handler.update_model_paths(m, side_input_model_path)
       model = self._model_handler.load_model()
       end_time = _to_milliseconds(self._clock.time_ns())
       memory_after = _get_current_process_memory_in_bytes()
@@ -1434,19 +1519,27 @@ def load():
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      models = []
+      for i in range(self._model_handler.model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _SharedModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _SharedModelWrapper([model], model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
     # of calling load(). For sanity check, call update_model_path again.
     if isinstance(side_input_model_path, str):
       self._model_handler.update_model_path(side_input_model_path)
     else:
-      self._model_handler.update_model_paths(self._model, side_input_model_path)
-    return model
+      if self._model is not None:
+        models = self._model.all_models()
+        for m in models:
+          self._model_handler.update_model_paths(m, side_input_model_path)
+    return model_wrapper
 
   def get_metrics_collector(self, prefix: str = ''):
     """
@@ -1476,8 +1569,9 @@ def update_model(
   def _run_inference(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     try:
+      model = self._model.next_model()
       result_generator = self._model_handler.run_inference(
-          batch, self._model, inference_args)
+          batch, model, inference_args)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 71 additions & 0 deletions
@@ -63,6 +63,15 @@ def increment_state(self, amount: int):
     self._state += amount
 
 
+class FakeIncrementingModel:
+  def __init__(self):
+    self._state = 0
+
+  def predict(self, example: int) -> int:
+    self._state += 1
+    return self._state
+
+
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
       self,
@@ -71,6 +80,8 @@ def __init__(
       max_batch_size=9999,
       multi_process_shared=False,
       state=None,
+      incrementing=False,
+      max_copies=1,
       num_bytes_per_element=None,
       **kwargs):
     self._fake_clock = clock
@@ -79,11 +90,16 @@ def __init__(
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
     self._state = state
+    self._incrementing = incrementing
+    self._max_copies = max_copies
     self._num_bytes_per_element = num_bytes_per_element
 
   def load_model(self):
+    assert (not self._incrementing or self._state is None)
     if self._fake_clock:
       self._fake_clock.current_time_ns += 500_000_000  # 500ms
+    if self._incrementing:
+      return FakeIncrementingModel()
     if self._state is not None:
       return FakeStatefulModel(self._state)
     return FakeModel()
@@ -116,6 +132,9 @@ def batch_elements_kwargs(self):
   def share_model_across_processes(self):
     return self._multi_process_shared
 
+  def model_copies(self):
+    return self._max_copies
+
   def get_num_bytes(self, batch: Sequence[int]) -> int:
     if self._num_bytes_per_element:
       return self._num_bytes_per_element * len(batch)
@@ -258,6 +277,58 @@ def test_run_inference_impl_simple_examples_multi_process_shared(self):
           FakeModelHandler(multi_process_shared=True))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared_multi_copy(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True, max_copies=4))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_multi_process_shared_incrementing_multi_copy(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(
+              multi_process_shared=True,
+              max_copies=4,
+              incrementing=True,
+              max_batch_size=1))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_mps_nobatch_incrementing_multi_copy(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      batched_examples = [[example] for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(batched_examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(
+              multi_process_shared=True, max_copies=4,
+              incrementing=True).with_no_batching())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_keyed_mps_incrementing_multi_copy(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      keyed_examples = [('abc', example) for example in examples]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      keyed_expected = [('abc', val) for val in expected]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(
+              FakeModelHandler(
+                  multi_process_shared=True,
+                  max_copies=4,
+                  incrementing=True,
+                  max_batch_size=1)))
+      assert_that(actual, equal_to(keyed_expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
