@@ -292,14 +292,16 @@ def __init__(self, output_processor, signature):
292292 self .output_processor = output_processor
293293 self .signature = signature
294294 self .user_state_context = None
295+ self .bundle_finalizer_param = None
295296
296297 @staticmethod
297298 def create_invoker (
298299 signature ,
299300 output_processor = None ,
300301 context = None , side_inputs = None , input_args = None , input_kwargs = None ,
301302 process_invocation = True ,
302- user_state_context = None ):
303+ user_state_context = None ,
304+ bundle_finalizer_param = None ):
303305 """ Creates a new DoFnInvoker based on given arguments.
304306
305307 Args:
@@ -321,6 +323,8 @@ def create_invoker(
321323 method efficiently.
322324 user_state_context: The UserStateContext instance for the current
323325 Stateful DoFn.
326+ bundle_finalizer_param: The param that passed to a process method, which
327+ allows a callback to be registered.
324328 """
325329 side_inputs = side_inputs or []
326330 default_arg_values = signature .process_method .defaults
@@ -333,7 +337,7 @@ def create_invoker(
333337 return PerWindowInvoker (
334338 output_processor ,
335339 signature , context , side_inputs , input_args , input_kwargs ,
336- user_state_context )
340+ user_state_context , bundle_finalizer_param )
337341
338342 def invoke_process (self , windowed_value , restriction_tracker = None ,
339343 output_processor = None ,
@@ -423,7 +427,8 @@ class PerWindowInvoker(DoFnInvoker):
423427 """An invoker that processes elements considering windowing information."""
424428
425429 def __init__ (self , output_processor , signature , context ,
426- side_inputs , input_args , input_kwargs , user_state_context ):
430+ side_inputs , input_args , input_kwargs , user_state_context ,
431+ bundle_finalizer_param ):
427432 super (PerWindowInvoker , self ).__init__ (output_processor , signature )
428433 self .side_inputs = side_inputs
429434 self .context = context
@@ -437,6 +442,7 @@ def __init__(self, output_processor, signature, context,
437442 self .is_splittable = signature .is_splittable_dofn ()
438443 self .restriction_tracker = None
439444 self .current_windowed_value = None
445+ self .bundle_finalizer_param = bundle_finalizer_param
440446
441447 # Try to prepare all the arguments that can just be filled in
442448 # without any additional work. in the process function.
@@ -487,6 +493,8 @@ def __init__(self, placeholder):
487493 args_with_placeholders .append (ArgPlaceholder (d ))
488494 elif isinstance (d , core .DoFn .TimerParam ):
489495 args_with_placeholders .append (ArgPlaceholder (d ))
496+ elif d == core .DoFn .BundleFinalizerParam :
497+ args_with_placeholders .append (ArgPlaceholder (d ))
490498 else :
491499 # If no more args are present then the value must be passed via kwarg
492500 try :
@@ -608,6 +616,8 @@ def _invoke_per_window(
608616 elif isinstance (p , core .DoFn .TimerParam ):
609617 args_for_process [i ] = (
610618 self .user_state_context .get_timer (p .timer_spec , key , window ))
619+ elif p == core .DoFn .BundleFinalizerParam :
620+ args_for_process [i ] = self .bundle_finalizer_param
611621
612622 if additional_kwargs :
613623 if kwargs_for_process is None :
@@ -694,6 +704,7 @@ def __init__(self,
694704
695705 self .step_name = step_name
696706 self .context = DoFnContext (step_name , state = state )
707+ self .bundle_finalizer_param = DoFn .BundleFinalizerParam ()
697708
698709 do_fn_signature = DoFnSignature (fn )
699710
@@ -722,7 +733,8 @@ def __init__(self,
722733
723734 self .do_fn_invoker = DoFnInvoker .create_invoker (
724735 do_fn_signature , output_processor , self .context , side_inputs , args ,
725- kwargs , user_state_context = user_state_context )
736+ kwargs , user_state_context = user_state_context ,
737+ bundle_finalizer_param = self .bundle_finalizer_param )
726738
727739 def receive (self , windowed_value ):
728740 self .process (windowed_value )
@@ -733,6 +745,9 @@ def process(self, windowed_value):
733745 except BaseException as exn :
734746 self ._reraise_augmented (exn )
735747
748+ def finalize (self ):
749+ self .bundle_finalizer_param .finalize_bundle ()
750+
736751 def process_with_restriction (self , windowed_value ):
737752 element , restriction = windowed_value .value
738753 return self .do_fn_invoker .invoke_process (
0 commit comments