Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d38645a

Browse files
boyuanzz and robertwb
authored and committed
[BEAM-6778] Enable Bundle Finalization in Python SDK harness over FnApi (#7937)
1 parent 4da6e93 commit d38645a

10 files changed

Lines changed: 222 additions & 31 deletions

File tree

sdks/python/apache_beam/runners/common.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cdef class DoFnInvoker(object):
5656
cdef public DoFnSignature signature
5757
cdef OutputProcessor output_processor
5858
cdef object user_state_context
59+
cdef public object bundle_finalizer_param
5960

6061
cpdef invoke_process(self, WindowedValue windowed_value,
6162
restriction_tracker=*,
@@ -92,7 +93,7 @@ cdef class DoFnRunner(Receiver):
9293
cdef object step_name
9394
cdef list side_inputs
9495
cdef DoFnInvoker do_fn_invoker
95-
96+
cdef public object bundle_finalizer_param
9697
cpdef process(self, WindowedValue windowed_value)
9798

9899

sdks/python/apache_beam/runners/common.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,16 @@ def __init__(self, output_processor, signature):
292292
self.output_processor = output_processor
293293
self.signature = signature
294294
self.user_state_context = None
295+
self.bundle_finalizer_param = None
295296

296297
@staticmethod
297298
def create_invoker(
298299
signature,
299300
output_processor=None,
300301
context=None, side_inputs=None, input_args=None, input_kwargs=None,
301302
process_invocation=True,
302-
user_state_context=None):
303+
user_state_context=None,
304+
bundle_finalizer_param=None):
303305
""" Creates a new DoFnInvoker based on given arguments.
304306
305307
Args:
@@ -321,6 +323,8 @@ def create_invoker(
321323
method efficiently.
322324
user_state_context: The UserStateContext instance for the current
323325
Stateful DoFn.
326+
bundle_finalizer_param: The parameter passed to the process method, which
327+
allows a callback to be registered.
324328
"""
325329
side_inputs = side_inputs or []
326330
default_arg_values = signature.process_method.defaults
@@ -333,7 +337,7 @@ def create_invoker(
333337
return PerWindowInvoker(
334338
output_processor,
335339
signature, context, side_inputs, input_args, input_kwargs,
336-
user_state_context)
340+
user_state_context, bundle_finalizer_param)
337341

338342
def invoke_process(self, windowed_value, restriction_tracker=None,
339343
output_processor=None,
@@ -423,7 +427,8 @@ class PerWindowInvoker(DoFnInvoker):
423427
"""An invoker that processes elements considering windowing information."""
424428

425429
def __init__(self, output_processor, signature, context,
426-
side_inputs, input_args, input_kwargs, user_state_context):
430+
side_inputs, input_args, input_kwargs, user_state_context,
431+
bundle_finalizer_param):
427432
super(PerWindowInvoker, self).__init__(output_processor, signature)
428433
self.side_inputs = side_inputs
429434
self.context = context
@@ -437,6 +442,7 @@ def __init__(self, output_processor, signature, context,
437442
self.is_splittable = signature.is_splittable_dofn()
438443
self.restriction_tracker = None
439444
self.current_windowed_value = None
445+
self.bundle_finalizer_param = bundle_finalizer_param
440446

441447
# Try to prepare all the arguments that can just be filled in
442448
# without any additional work. in the process function.
@@ -487,6 +493,8 @@ def __init__(self, placeholder):
487493
args_with_placeholders.append(ArgPlaceholder(d))
488494
elif isinstance(d, core.DoFn.TimerParam):
489495
args_with_placeholders.append(ArgPlaceholder(d))
496+
elif d == core.DoFn.BundleFinalizerParam:
497+
args_with_placeholders.append(ArgPlaceholder(d))
490498
else:
491499
# If no more args are present then the value must be passed via kwarg
492500
try:
@@ -608,6 +616,8 @@ def _invoke_per_window(
608616
elif isinstance(p, core.DoFn.TimerParam):
609617
args_for_process[i] = (
610618
self.user_state_context.get_timer(p.timer_spec, key, window))
619+
elif p == core.DoFn.BundleFinalizerParam:
620+
args_for_process[i] = self.bundle_finalizer_param
611621

612622
if additional_kwargs:
613623
if kwargs_for_process is None:
@@ -694,6 +704,7 @@ def __init__(self,
694704

695705
self.step_name = step_name
696706
self.context = DoFnContext(step_name, state=state)
707+
self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
697708

698709
do_fn_signature = DoFnSignature(fn)
699710

@@ -722,7 +733,8 @@ def __init__(self,
722733

723734
self.do_fn_invoker = DoFnInvoker.create_invoker(
724735
do_fn_signature, output_processor, self.context, side_inputs, args,
725-
kwargs, user_state_context=user_state_context)
736+
kwargs, user_state_context=user_state_context,
737+
bundle_finalizer_param=self.bundle_finalizer_param)
726738

727739
def receive(self, windowed_value):
728740
self.process(windowed_value)
@@ -733,6 +745,9 @@ def process(self, windowed_value):
733745
except BaseException as exn:
734746
self._reraise_augmented(exn)
735747

748+
def finalize(self):
  """Runs bundle finalization for this runner.

  Delegates to ``finalize_bundle()`` on this runner's
  ``bundle_finalizer_param``.
  """
  self.bundle_finalizer_param.finalize_bundle()
750+
736751
def process_with_restriction(self, windowed_value):
737752
element, restriction = windowed_value.value
738753
return self.do_fn_invoker.invoke_process(

sdks/python/apache_beam/runners/portability/flink_runner_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,12 @@ def process(self, v):
220220
def test_sdf(self):
221221
raise unittest.SkipTest("BEAM-2939")
222222

223+
def test_callbacks_with_exception(self):
224+
raise unittest.SkipTest("BEAM-6868")
225+
226+
def test_register_finalizations(self):
227+
raise unittest.SkipTest("BEAM-6868")
228+
223229
# Inherits all other tests.
224230

225231
# Run the tests.

sdks/python/apache_beam/runners/portability/fn_api_runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,15 @@ def process_bundle(self, inputs, expected_outputs):
12981298
if result.error:
12991299
raise RuntimeError(result.error)
13001300

1301+
if result.process_bundle.requires_finalization:
1302+
finalize_request = beam_fn_api_pb2.InstructionRequest(
1303+
finalize_bundle=
1304+
beam_fn_api_pb2.FinalizeBundleRequest(
1305+
instruction_reference=process_bundle_id
1306+
))
1307+
self._controller.control_handler.push(
1308+
finalize_request)
1309+
13011310
return result, split_results
13021311

13031312

sdks/python/apache_beam/runners/portability/fn_api_runner_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import os
2323
import random
24+
import shutil
2425
import sys
2526
import tempfile
2627
import threading
@@ -801,6 +802,50 @@ def contains_labels(monitoring_info, labels):
801802
print(res._monitoring_infos_by_stage)
802803
raise
803804

805+
def test_callbacks_with_exception(self):
  """Runs a DoFn whose registered finalization callback raises.

  Every processed element registers a callback that raises an exception.
  The pipeline output is still asserted to match the input elements.
  """
  elements_list = ['1', '2']

  def raise_exception():
    raise Exception('raise exception when calling callback')

  class FinalizableDoFnWithException(beam.DoFn):

    def process(
        self,
        element,
        bundle_finalizer=beam.DoFn.BundleFinalizerParam):
      bundle_finalizer.register(raise_exception)
      yield element

  with self.create_pipeline() as p:
    res = (p
           | beam.Create(elements_list)
           | beam.ParDo(FinalizableDoFnWithException()))
    assert_that(res, equal_to(elements_list))
825+
826+
def test_register_finalizations(self):
  """Checks that callbacks registered during processing are invoked.

  Each element registers a callback that records the element through an
  EventRecorder. After the pipeline has run, the recorded events are
  compared with the sorted input elements.
  """
  event_recorder = EventRecorder(tempfile.gettempdir())
  elements_list = ['2', '1']

  class FinalizableDoFn(beam.DoFn):
    def process(
        self,
        element,
        bundle_finalizer=beam.DoFn.BundleFinalizerParam):
      bundle_finalizer.register(lambda: event_recorder.record(element))
      yield element

  try:
    with self.create_pipeline() as p:
      res = (p
             | beam.Create(elements_list)
             | beam.ParDo(FinalizableDoFn()))

      assert_that(res, equal_to(elements_list))

    # Read the records only after the pipeline context has exited.
    results = event_recorder.events()
  finally:
    # Remove the temporary directory even when the pipeline fails.
    event_recorder.cleanup()
  self.assertEqual(results, sorted(elements_list))
848+
804849

805850
class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
806851

@@ -827,6 +872,9 @@ def create_pipeline(self):
827872
return beam.Pipeline(
828873
runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
829874

875+
def test_register_finalizations(self):
876+
raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")
877+
830878

831879
class FnApiRunnerSplitTest(unittest.TestCase):
832880

@@ -1084,6 +1132,34 @@ def _unpickle_element_counter(name):
10841132
return _pickled_element_counters[name]
10851133

10861134

1135+
class EventRecorder(object):
  """Used to be registered as a callback in bundle finalization.

  Each record is written to its own file in a temporary directory. Files are
  used because an in-memory collection does not retain the callback records
  once the recorder has been passed into a DoFn.
  """

  def __init__(self, tmp_dir):
    """Creates a uniquely named record directory under ``tmp_dir``.

    Args:
      tmp_dir: An existing directory in which the record directory is created.
    """
    self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex)
    os.mkdir(self.tmp_dir)

  def record(self, content):
    """Writes ``content`` to a new, uniquely named ``.txt`` file.

    Args:
      content: The text to store as one event.
    """
    file_path = os.path.join(self.tmp_dir, uuid.uuid4().hex + '.txt')
    with open(file_path, 'w') as f:
      f.write(content)

  def events(self):
    """Returns the contents of every record file as a sorted list."""
    content = []
    for name in os.listdir(self.tmp_dir):
      path = os.path.join(self.tmp_dir, name)
      # Only regular files are records; skip anything else in the directory.
      if not os.path.isfile(path):
        continue
      with open(path, 'r') as f:
        content.append(f.read())
    return sorted(content)

  def cleanup(self):
    """Removes the record directory; calling it again is a no-op."""
    if os.path.isdir(self.tmp_dir):
      shutil.rmtree(self.tmp_dir)
1162+
10871163
if __name__ == '__main__':
10881164
logging.getLogger().setLevel(logging.INFO)
10891165
unittest.main()

sdks/python/apache_beam/runners/worker/bundle_processor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,6 @@ def __init__(
469469
self.splitting_lock = threading.Lock()
470470

471471
def create_execution_tree(self, descriptor):
472-
473472
transform_factory = BeamTransformFactory(
474473
descriptor, self.data_channel_factory, self.counter_factory,
475474
self.state_sampler, self.state_handler)
@@ -559,16 +558,24 @@ def process_bundle(self, instruction_id):
559558
logging.debug('finish %s', op)
560559
op.finish()
561560

562-
return [
563-
self.delayed_bundle_application(op, residual)
564-
for op, residual in execution_context.delayed_applications]
561+
return ([self.delayed_bundle_application(op, residual)
562+
for op, residual in execution_context.delayed_applications],
563+
self.requires_finalization())
565564

566565
finally:
567566
# Ensure any in-flight split attempts complete.
568567
with self.splitting_lock:
569568
pass
570569
self.state_sampler.stop_if_still_running()
571570

571+
def finalize_bundle(self):
  """Calls ``finalize_bundle()`` on every operation in ``self.ops``.

  Returns:
    A newly constructed ``beam_fn_api_pb2.FinalizeBundleResponse``.
  """
  for op in self.ops.values():
    op.finalize_bundle()
  return beam_fn_api_pb2.FinalizeBundleResponse()
575+
576+
def requires_finalization(self):
  """Returns True if any operation in ``self.ops`` needs finalization."""
  return any(op.needs_finalization() for op in self.ops.values())
578+
572579
def try_split(self, bundle_split_request):
573580
split_response = beam_fn_api_pb2.ProcessBundleSplitResponse()
574581
with self.splitting_lock:

sdks/python/apache_beam/runners/worker/operations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ def process(self, o):
223223
"""Process element in operation."""
224224
pass
225225

226+
def finalize_bundle(self):
  """Default bundle finalization hook; does nothing."""
  pass
228+
229+
def needs_finalization(self):
  """Default answer for whether finalization is needed: always False."""
  return False
231+
226232
def try_split(self, fraction_of_remainder):
227233
return None
228234

@@ -557,6 +563,12 @@ def process(self, o):
557563
self.execution_context.delayed_applications.append(
558564
(self, delayed_application))
559565

566+
def finalize_bundle(self):
  """Finalizes the bundle by calling ``finalize()`` on ``dofn_receiver``."""
  self.dofn_receiver.finalize()
568+
569+
def needs_finalization(self):
  """Returns whether the receiver's bundle finalizer has callbacks.

  Returns:
    The result of ``has_callbacks()`` on the ``bundle_finalizer_param`` of
    ``dofn_receiver``.
  """
  return self.dofn_receiver.bundle_finalizer_param.has_callbacks()
571+
560572
def process_timer(self, tag, windowed_timer):
561573
key, timer_data = windowed_timer.value
562574
timer_spec = self.timer_specs[tag]
@@ -575,6 +587,7 @@ def reset(self):
575587
side_input_map.reset()
576588
if self.user_state_context:
577589
self.user_state_context.reset()
590+
self.dofn_receiver.bundle_finalizer_param.reset()
578591

579592
def progress_metrics(self):
580593
metrics = super(DoOperation, self).progress_metrics()

0 commit comments

Comments
 (0)