@@ -50,11 +50,17 @@ def predict(self, example: int) -> int:
 
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
-      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      multi_process_shared=False,
+      **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     if self._fake_clock:
@@ -66,6 +72,12 @@ def run_inference(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[int]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
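# A minimal sketch (illustrative, outside this diff) of why the
# str(type(model)) check above works: when share_model_across_processes()
# returns True, RunInference loads the model through
# apache_beam.utils.multi_process_shared, and what it acquires is a
# cross-process proxy whose type lives in that module. Assumes
# MultiProcessShared's constructor/tag/always_proxy arguments and its
# acquire() method.
from apache_beam.utils import multi_process_shared

shared_handle = multi_process_shared.MultiProcessShared(
    FakeModel, tag='fake_model', always_proxy=True)
model_proxy = shared_handle.acquire()  # a proxy, not a bare FakeModel
assert 'multi_process_shared' in str(type(model_proxy))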
@@ -80,13 +92,21 @@ def batch_elements_kwargs(self):
         'max_batch_size': self._max_batch_size
     }
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None, model_id='fake_model_id_default'):
+  def __init__(
+      self,
+      clock=None,
+      model_id='fake_model_id_default',
+      multi_process_shared=False):
     self.model_id = model_id
     self._fake_clock = clock
     self._env_vars = {}
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     return FakeModel()
@@ -96,6 +116,12 @@ def run_inference(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[base.PredictionResult]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     for example in batch:
       yield base.PredictionResult(
           model_id=self.model_id,
@@ -105,6 +131,9 @@ def run_inference(
   def update_model_path(self, model_path: Optional[str] = None):
     self.model_id = model_path if model_path else self.model_id
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]):
   def __init__(
@@ -188,6 +217,15 @@ def test_run_inference_impl_simple_examples(self):
       actual = pcoll | base.RunInference(FakeModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
@@ -215,6 +253,35 @@ def test_run_inference_impl_with_maybe_keyed_examples(self):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(
+          FakeModelHandler(multi_process_shared=True))
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_preprocessing(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -666,6 +733,31 @@ def test_run_inference_with_iterable_side_input(self):
         'singleton view. First two elements encountered are' in str(
             e.exception))
 
+  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
+    test_pipeline = TestPipeline()
+    side_input = (
+        test_pipeline | "CreateDummySideInput" >> beam.Create(
+            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
+        | "ApplySideInputWindow" >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
+
+    test_pipeline.options.view_as(StandardOptions).streaming = True
+    with self.assertRaises(ValueError) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(
+              FakeModelHandler(multi_process_shared=True),
+              model_metadata_pcoll=side_input))
+      test_pipeline.run()
+
+    self.assertTrue(
+        'PCollection of size 2 with more than one element accessed as a '
+        'singleton view. First two elements encountered are' in str(
+            e.exception))
+
   def test_run_inference_empty_side_input(self):
     model_handler = FakeModelHandlerReturnsPredictionResult()
     main_input_elements = [1, 2]
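# A contrasting sketch (illustrative, outside this diff): the ValueError
# asserted above is Beam's generic singleton side-input check, since
# model_metadata_pcoll is read as a singleton per window and the test emits
# two ModelMetadata elements into one global window. A shape RunInference
# does accept is a single update per window, e.g.:
single_update = (
    test_pipeline
    | 'CreateSingleUpdate' >> beam.Create([
        base.ModelMetadata(
            model_id='fake_model_id_1', model_name='fake_model_id_1')
    ]))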
@@ -759,6 +851,84 @@ def process(self, element):
 
       assert_that(result_pcoll, equal_to(expected_result))
 
+  def test_run_inference_side_input_in_batch_multi_process_shared(self):
+    first_ts = math.floor(time.time()) - 30
+    interval = 7
+
+    sample_main_input_elements = ([
+        first_ts - 2,
+        first_ts + 1,
+        first_ts + 8,
+        first_ts + 15,
+        first_ts + 22,
+    ])
+
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
+
+    model_handler = FakeModelHandlerReturnsPredictionResult(
+        multi_process_shared=True)
+
+    # applying GroupByKey to utilize windowing according to
+    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
+    class _EmitElement(beam.DoFn):
+      def process(self, element):
+        for e in element:
+          yield e
+
+    with TestPipeline() as pipeline:
+      side_input = (
+          pipeline
+          |
+          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+          | beam.WindowInto(
+              window.FixedWindows(interval),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | beam.Map(lambda x: ('key', x))
+          | beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+      result_pcoll = (
+          pipeline
+          | beam.Create(sample_main_input_elements)
+          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
+          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+          | beam.Map(lambda x: ('key', x))
+          | "MainInputGBK" >> beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | beam.ParDo(_EmitElement())
+          | "RunInference" >> base.RunInference(
+              model_handler, model_metadata_pcoll=side_input))
+
+      expected_model_id_order = [
+          'fake_model_id_default',
+          'fake_model_id_default',
+          'fake_model_id_1',
+          'fake_model_id_2',
+          'fake_model_id_2'
+      ]
+      expected_result = [
+          base.PredictionResult(
+              example=sample_main_input_elements[i],
+              inference=sample_main_input_elements[i] + 1,
+              model_id=expected_model_id_order[i]) for i in range(5)
+      ]
+
+      assert_that(result_pcoll, equal_to(expected_result))
+
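# How the expected ordering above falls out (editorial note, outside this
# diff): with 7-second fixed windows, the main elements at first_ts - 2 and
# first_ts + 1 see at most the empty-model_id update, which maps to the
# handler default, so both resolve to 'fake_model_id_default'; the elements
# at first_ts + 8 and first_ts + 15 land in windows carrying the
# 'fake_model_id_1' and 'fake_model_id_2' updates; and first_ts + 22 arrives
# after the last update, so it reuses 'fake_model_id_2'.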
   @unittest.skipIf(
       not TestPipeline().get_pipeline_options().view_as(
           StandardOptions).streaming,