Commit 6a99cfe

Allow model handlers to request multi_process_shared model (#26688)
* Allow model handlers to request multi_process_shared model
* Remove resolved todo
1 parent 4eb8956 commit 6a99cfe
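
The change is opt-in per handler: ModelHandler gains a share_model_across_processes() hook that defaults to False, and RunInference consults it when loading the model. As a rough illustration of how a custom handler might opt in (the handler and model classes below are illustrative, not part of this commit):

import apache_beam as beam
from apache_beam.ml.inference import base


class SimpleModel:
  """Stand-in for a model that is expensive to keep in memory per process."""
  def predict(self, example: int) -> int:
    return example + 1


class SharedModelHandler(base.ModelHandler[int, int, SimpleModel]):
  """Illustrative handler that asks RunInference to share its model."""
  def load_model(self) -> SimpleModel:
    return SimpleModel()

  def run_inference(self, batch, model, inference_args=None):
    for example in batch:
      yield model.predict(example)

  def share_model_across_processes(self) -> bool:
    # Opt in to the cross-process sharing added by this commit; the base
    # class default remains False (one model per process).
    return True


# Used like any other model handler:
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([1, 5, 3, 10])
      | base.RunInference(SharedModelHandler()))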

2 files changed: 204 additions & 5 deletions

File tree

sdks/python/apache_beam/ml/inference/base.py
sdks/python/apache_beam/ml/inference/base_test.py

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 32 additions & 3 deletions
@@ -33,6 +33,7 @@
 import sys
 import threading
 import time
+import uuid
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -47,6 +48,7 @@
 from typing import Union

 import apache_beam as beam
+from apache_beam.utils import multi_process_shared
 from apache_beam.utils import shared

 try:
@@ -230,6 +232,15 @@ def with_postprocess_fn(
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)

+  def share_model_across_processes(self) -> bool:
+    """Returns a boolean representing whether or not a model should
+    be shared across multiple processes instead of being loaded per process.
+    This is primary useful for large models that can't fit multiple copies in
+    memory. Multi-process support may vary by runner, but this will fallback to
+    loading per process as necessary. See
+    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
+    return False
+

 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -293,6 +304,9 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()

+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+

 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -382,6 +396,9 @@ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()

+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+

 class _PreProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
@@ -541,6 +558,9 @@ def __init__(
     self._with_exception_handling = False
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
+    # Generate a random tag to use for shared.py and multi_process_shared.py to
+    # allow us to effectively disambiguate in multi-model settings.
+    self._model_tag = uuid.uuid4().hex

   def _get_model_metadata_pcoll(self, pipeline):
     # avoid circular imports.
@@ -626,7 +646,8 @@ def expand(
                     self._model_handler,
                     self._clock,
                     self._metrics_namespace,
-                    self._enable_side_input_loading),
+                    self._enable_side_input_loading,
+                    self._model_tag),
                 self._inference_args,
                 beam.pvalue.AsSingleton(
                     self._model_metadata_pcoll,
@@ -783,7 +804,8 @@ def __init__(
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
       metrics_namespace,
-      enable_side_input_loading: bool = False):
+      enable_side_input_loading: bool = False,
+      model_tag: str = "RunInference"):
     """A DoFn implementation generic to frameworks.

     Args:
@@ -792,6 +814,7 @@ def __init__(
       metrics_namespace: Namespace of the transform to collect metrics.
       enable_side_input_loading: Bool to indicate if model updates
         with side inputs.
+      model_tag: Tag to use to disambiguate models in multi-model settings.
     """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
@@ -800,6 +823,7 @@ def __init__(
     self._metrics_namespace = metrics_namespace
     self._enable_side_input_loading = enable_side_input_loading
     self._side_input_path = None
+    self._model_tag = model_tag

   def _load_model(self, side_input_model_path: Optional[str] = None):
     def load():
@@ -818,7 +842,12 @@ def load():

     # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
     # model.
-    model = self._shared_model_handle.acquire(load, tag=side_input_model_path)
+    if self._model_handler.share_model_across_processes():
+      model = multi_process_shared.MultiProcessShared(
+          load, tag=side_input_model_path or self._model_tag).acquire()
+    else:
+      model = self._shared_model_handle.acquire(
+          load, tag=side_input_model_path or self._model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
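
When the handler opts in, _RunInferenceDoFn routes loading through apache_beam.utils.multi_process_shared instead of the per-process shared.Shared cache, keyed by the side-input model path or the transform's generated uuid tag. A minimal standalone sketch of that utility, assuming only the MultiProcessShared(constructor, tag=...).acquire() pattern shown in the diff (the class and tag names are illustrative):

from apache_beam.utils import multi_process_shared


class ExpensiveModel:
  """Placeholder for a model too large to hold once per process."""
  def predict(self, x):
    return x + 1


# Processes on the same worker that acquire the same tag are served a proxy
# to a single ExpensiveModel instance instead of each loading its own copy.
model = multi_process_shared.MultiProcessShared(
    ExpensiveModel, tag='my_model_tag').acquire()
result = model.predict(1)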

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 172 additions & 2 deletions
@@ -50,11 +50,17 @@ def predict(self, example: int) -> int:

 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
-      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      multi_process_shared=False,
+      **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared

   def load_model(self):
     if self._fake_clock:
@@ -66,6 +72,12 @@ def run_inference(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[int]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
@@ -80,13 +92,21 @@ def batch_elements_kwargs(self):
         'max_batch_size': self._max_batch_size
     }

+  def share_model_across_processes(self):
+    return self._multi_process_shared
+

 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None, model_id='fake_model_id_default'):
+  def __init__(
+      self,
+      clock=None,
+      model_id='fake_model_id_default',
+      multi_process_shared=False):
     self.model_id = model_id
     self._fake_clock = clock
     self._env_vars = {}
+    self._multi_process_shared = multi_process_shared

   def load_model(self):
     return FakeModel()
@@ -96,6 +116,12 @@ def run_inference(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[base.PredictionResult]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     for example in batch:
       yield base.PredictionResult(
           model_id=self.model_id,
@@ -105,6 +131,9 @@ def run_inference(
   def update_model_path(self, model_path: Optional[str] = None):
     self.model_id = model_path if model_path else self.model_id

+  def share_model_across_processes(self):
+    return self._multi_process_shared
+

 class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]):
   def __init__(
@@ -188,6 +217,15 @@ def test_run_inference_impl_simple_examples(self):
       actual = pcoll | base.RunInference(FakeModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')

+  def test_run_inference_impl_simple_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
@@ -215,6 +253,35 @@ def test_run_inference_impl_with_maybe_keyed_examples(self):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

+  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(
+          FakeModelHandler(multi_process_shared=True))
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_preprocessing(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -666,6 +733,31 @@ def test_run_inference_with_iterable_side_input(self):
         'singleton view. First two elements encountered are' in str(
             e.exception))

+  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
+    test_pipeline = TestPipeline()
+    side_input = (
+        test_pipeline | "CreateDummySideInput" >> beam.Create(
+            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
+        | "ApplySideInputWindow" >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
+
+    test_pipeline.options.view_as(StandardOptions).streaming = True
+    with self.assertRaises(ValueError) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(
+              FakeModelHandler(multi_process_shared=True),
+              model_metadata_pcoll=side_input))
+      test_pipeline.run()
+
+    self.assertTrue(
+        'PCollection of size 2 with more than one element accessed as a '
+        'singleton view. First two elements encountered are' in str(
+            e.exception))
+
   def test_run_inference_empty_side_input(self):
     model_handler = FakeModelHandlerReturnsPredictionResult()
     main_input_elements = [1, 2]
@@ -759,6 +851,84 @@ def process(self, element):

       assert_that(result_pcoll, equal_to(expected_result))

+  def test_run_inference_side_input_in_batch_multi_process_shared(self):
+    first_ts = math.floor(time.time()) - 30
+    interval = 7
+
+    sample_main_input_elements = ([
+        first_ts - 2,
+        first_ts + 1,
+        first_ts + 8,
+        first_ts + 15,
+        first_ts + 22,
+    ])
+
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
+
+    model_handler = FakeModelHandlerReturnsPredictionResult(
+        multi_process_shared=True)
+
+    # applying GroupByKey to utilize windowing according to
+    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
+    class _EmitElement(beam.DoFn):
+      def process(self, element):
+        for e in element:
+          yield e
+
+    with TestPipeline() as pipeline:
+      side_input = (
+          pipeline
+          |
+          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+          | beam.WindowInto(
+              window.FixedWindows(interval),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | beam.Map(lambda x: ('key', x))
+          | beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+      result_pcoll = (
+          pipeline
+          | beam.Create(sample_main_input_elements)
+          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
+          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+          | beam.Map(lambda x: ('key', x))
+          | "MainInputGBK" >> beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | beam.ParDo(_EmitElement())
+          | "RunInference" >> base.RunInference(
+              model_handler, model_metadata_pcoll=side_input))
+
+      expected_model_id_order = [
+          'fake_model_id_default',
+          'fake_model_id_default',
+          'fake_model_id_1',
+          'fake_model_id_2',
+          'fake_model_id_2'
+      ]
+      expected_result = [
+          base.PredictionResult(
+              example=sample_main_input_elements[i],
+              inference=sample_main_input_elements[i] + 1,
+              model_id=expected_model_id_order[i]) for i in range(5)
+      ]
+
+      assert_that(result_pcoll, equal_to(expected_result))
+
   @unittest.skipIf(
       not TestPipeline().get_pipeline_options().view_as(
           StandardOptions).streaming,
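
A note on the assertion style in the fake handlers above: a model served through MultiProcessShared arrives as a proxy object rather than the raw FakeModel, so the tests detect which caching path was taken by checking whether the loaded object's type references the multi_process_shared module. A condensed sketch of that check (the helper name is illustrative):

def loaded_via_multi_process_shared(model) -> bool:
  # A proxy returned by MultiProcessShared has a type whose repr mentions the
  # multi_process_shared module; a model loaded in-process does not.
  return "multi_process_shared" in str(type(model))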
