Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c2dfa8d

Browse files
author
Boyuan Zhang
committed
Clean up and add type-hints to SDF API
1 parent 00ed8a8 commit c2dfa8d

4 files changed

Lines changed: 23 additions & 20 deletions

File tree

sdks/python/apache_beam/io/iobase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,11 +1263,13 @@ def get_estimator_state(self):
12631263
raise NotImplementedError(type(self))
12641264

12651265
def current_watermark(self):
1266+
# type: () -> timestamp.Timestamp
12661267
"""Return estimated output_watermark. This function must return
12671268
monotonically increasing watermarks."""
12681269
raise NotImplementedError(type(self))
12691270

12701271
def observe_timestamp(self, timestamp):
1272+
# type: (timestamp.Timestamp) -> None
12711273
"""Update tracking watermark with latest output timestamp.
12721274
12731275
Args:

sdks/python/apache_beam/runners/common.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
if TYPE_CHECKING:
6868
from apache_beam.transforms import sideinputs
6969
from apache_beam.transforms.core import TimerSpec
70+
from apache_beam.iobase import WatermarkEstimator
71+
from apache_beam.iobase import RestrictionTracker
7072

7173

7274
class NameContext(object):
@@ -296,6 +298,7 @@ def get_restriction_provider(self):
296298
return self.process_method.restriction_provider
297299

298300
def get_watermark_estimator_provider(self):
301+
# type: () -> WatermarkEstimatorProvider
299302
return self.process_method.watermark_estimator_provider
300303

301304
def _validate(self):
@@ -333,6 +336,7 @@ def is_splittable_dofn(self):
333336
return self.get_restriction_provider() is not None
334337

335338
def get_restriction_coder(self):
339+
# type: () -> Optional[TupleCoder]
336340
"""Get coder for a restriction when processing an SDF. """
337341
if self.is_splittable_dofn():
338342
restriction_coder = TupleCoder([
@@ -437,11 +441,11 @@ def create_invoker(
437441
def invoke_process(self,
438442
windowed_value, # type: WindowedValue
439443
restriction_tracker=None, # type: Optional[RestrictionTracker]
440-
watermark_estimator=None,
444+
watermark_estimator=None, # type: Optional[WatermarkEstimator]
441445
additional_args=None,
442446
additional_kwargs=None
443447
):
444-
# type: (...) -> Optional[SplitResultType]
448+
# type: (...) -> Optional[SplitResultResidual]
445449

446450
"""Invokes the DoFn.process() function.
447451
@@ -524,7 +528,7 @@ def __init__(self,
524528
def invoke_process(self,
525529
windowed_value, # type: WindowedValue
526530
restriction_tracker=None, # type: Optional[RestrictionTracker]
527-
watermark_estimator=None,
531+
watermark_estimator=None, # type: Optional[WatermarkEstimator]
528532
additional_args=None,
529533
additional_kwargs=None
530534
):
@@ -557,8 +561,8 @@ def __init__(self,
557561
signature.is_stateful_dofn())
558562
self.user_state_context = user_state_context
559563
self.is_splittable = signature.is_splittable_dofn()
560-
self.threadsafe_restriction_tracker = None
561-
self.threadsafe_watermark_estimator = None
564+
self.threadsafe_restriction_tracker = None # type: Optional[ThreadsafeRestrictionTracker]
565+
self.threadsafe_watermark_estimator = None # type: Optional[ThreadsafeWatermarkEstimator]
562566
self.current_windowed_value = None # type: Optional[WindowedValue]
563567
self.bundle_finalizer_param = bundle_finalizer_param
564568
self.is_key_param_required = False
@@ -640,12 +644,12 @@ def __init__(self, placeholder):
640644

641645
def invoke_process(self,
642646
windowed_value, # type: WindowedValue
643-
restriction_tracker=None,
644-
watermark_estimator=None,
647+
restriction_tracker=None, # type: Optional[RestrictionTracker]
648+
watermark_estimator=None, # type: Optional[WatermarkEstimator]
645649
additional_args=None,
646650
additional_kwargs=None
647651
):
648-
# type: (...) -> Optional[SplitResultType]
652+
# type: (...) -> Optional[SplitResultResidual]
649653
if not additional_args:
650654
additional_args = []
651655
if not additional_kwargs:
@@ -790,9 +794,6 @@ def _invoke_process_per_window(self,
790794

791795
if self.is_splittable:
792796
assert self.threadsafe_restriction_tracker is not None
793-
# TODO: Consider calling check_done right after SDF.Process() finishing.
794-
# In order to do this, we need to know that current invoking dofn is
795-
# ProcessSizedElementAndRestriction.
796797
self.threadsafe_restriction_tracker.check_done()
797798
deferred_status = self.threadsafe_restriction_tracker.deferred_status()
798799
if deferred_status:

sdks/python/apache_beam/runners/sdf_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
if TYPE_CHECKING:
4040
from apache_beam.io.iobase import RestrictionTracker
41+
from apache_beam.io.iobase import RestrictionProgress
42+
from apache_beam.io.iobase import WatermarkEstimator
4143

4244
_LOGGER = logging.getLogger(__name__)
4345

@@ -113,6 +115,7 @@ def check_done(self):
113115
return self._restriction_tracker.check_done()
114116

115117
def current_progress(self):
118+
# type: () -> RestrictionProgress
116119
with self._lock:
117120
return self._restriction_tracker.current_progress()
118121

@@ -158,6 +161,7 @@ class RestrictionTrackerView(object):
158161
restriction_tracker.
159162
"""
160163
def __init__(self, threadsafe_restriction_tracker):
164+
# type: (ThreadsafeRestrictionTracker) -> None
161165
if not isinstance(threadsafe_restriction_tracker,
162166
ThreadsafeRestrictionTracker):
163167
raise ValueError(
@@ -180,6 +184,7 @@ class ThreadsafeWatermarkEstimator(object):
180184
mechanism to guarantee multi-thread safety.
181185
"""
182186
def __init__(self, watermark_estimator):
187+
# type: (WatermarkEstimator) -> None
183188
from apache_beam.io.iobase import WatermarkEstimator
184189
if not isinstance(watermark_estimator, WatermarkEstimator):
185190
raise ValueError('Initializing Threadsafe requires a WatermarkEstimator')
@@ -200,19 +205,13 @@ def get_estimator_state(self):
200205
with self._lock:
201206
return self._watermark_estimator.get_estimator_state()
202207

203-
def current_watermark_with_lock(self):
204-
# The caller should hold the lock before entering this function.
205-
if not self._lock.locked():
206-
raise RuntimeError(
207-
'Expected lock to be held to guarantee thread-safe '
208-
'access.')
209-
return self._watermark_estimator.current_watermark()
210-
211208
def current_watermark(self):
209+
# type: () -> Timestamp
212210
with self._lock:
213-
return self.current_watermark_with_lock()
211+
return self._watermark_estimator.current_watermark()
214212

215213
def observe_timestamp(self, timestamp):
214+
# type: (Timestamp) -> None
216215
if not isinstance(timestamp, Timestamp):
217216
raise ValueError(
218217
'Input of observe_timestamp should be a Timestamp '

sdks/python/apache_beam/transforms/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def reset(self):
523523
class _WatermarkEstimatorParam(_DoFnParam):
524524
"""WatermarkEstimator DoFn parameter."""
525525
def __init__(self, watermark_estimator_provider):
526+
# type: (WatermarkEstimatorProvider) -> None
526527
if not isinstance(watermark_estimator_provider, WatermarkEstimatorProvider):
527528
raise ValueError(
528529
'DoFn._WatermarkEstimatorParam expected'

0 commit comments

Comments
 (0)