-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathsdf_utils_test.py
More file actions
138 lines (117 loc) · 5.74 KB
/
sdf_utils_test.py
File metadata and controls
138 lines (117 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for classes in sdf_utils.py."""
# pytype: skip-file
import time
import unittest
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io.restriction_trackers import OffsetRange
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
from apache_beam.utils import timestamp
class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
  """Tests for ThreadsafeRestrictionTracker construction and self-checkpoint.

  The self-checkpoint tests inspect the residual and the deferred time that
  ``deferred_status()`` reports after ``defer_remainder()`` is called with no
  argument, with a relative ``Duration``, or with an absolute ``Timestamp``.
  """
  def test_initialization(self):
    # A RangeSource is a source, not a restriction tracker, so wrapping it is
    # expected to be rejected.
    with self.assertRaises(ValueError):
      ThreadsafeRestrictionTracker(RangeSource(0, 1))

  def test_defer_remainder_with_wrong_time_type(self):
    threadsafe_tracker = ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    # A plain int is not an accepted time type (the other tests pass a
    # Duration or a Timestamp), so ValueError is expected.
    with self.assertRaises(ValueError):
      threadsafe_tracker.defer_remainder(10)

  def test_self_checkpoint_immediately(self):
    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
    threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker)
    threadsafe_tracker.defer_remainder()
    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
    # Nothing has been claimed yet, so the whole range is deferred.
    expected_residual = OffsetRange(0, 10)
    self.assertEqual(deferred_residual, expected_residual)
    self.assertIsInstance(deferred_time, timestamp.Duration)
    # Without a time argument the remainder is deferred with zero delay.
    self.assertEqual(deferred_time, 0)

  def test_self_checkpoint_with_relative_time(self):
    threadsafe_tracker = ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
    time.sleep(2)
    _, deferred_time = threadsafe_tracker.deferred_status()
    self.assertIsInstance(deferred_time, timestamp.Duration)
    # The expectation = 100 - 2 - some_delta.
    # The upper bound shows the time spent sleeping was subtracted. The lower
    # bound rejects a zero/negative delay, e.g. the requested delay being
    # dropped; it only fails if ~100s elapse, so it is not timing-flaky.
    self.assertLessEqual(deferred_time, 98)
    self.assertGreater(deferred_time, 0)

  def test_self_checkpoint_with_absolute_time(self):
    threadsafe_tracker = ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    now = timestamp.Timestamp.now()
    schedule_time = now + timestamp.Duration(100)
    self.assertIsInstance(schedule_time, timestamp.Timestamp)
    threadsafe_tracker.defer_remainder(schedule_time)
    time.sleep(2)
    _, deferred_time = threadsafe_tracker.deferred_status()
    # An absolute Timestamp is reported back as a relative Duration.
    self.assertIsInstance(deferred_time, timestamp.Duration)
    # The expectation =
    # schedule_time - the time when deferred_status is called - some_delta
    # Both bounds, as in the relative-time test above.
    self.assertLessEqual(deferred_time, 98)
    self.assertGreater(deferred_time, 0)
class RestrictionTrackerViewTest(unittest.TestCase):
  """Checks which tracker methods a RestrictionTrackerView makes reachable."""
  def test_initialization(self):
    # The view must be given a ThreadsafeRestrictionTracker; a bare
    # OffsetRestrictionTracker is rejected.
    with self.assertRaises(ValueError):
      RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10)))

  def test_api_expose(self):
    safe_tracker = ThreadsafeRestrictionTracker(
        OffsetRestrictionTracker(OffsetRange(0, 10)))
    view = RestrictionTrackerView(safe_tracker)

    self.assertEqual(view.current_restriction(), OffsetRange(0, 10))
    self.assertTrue(view.try_claim(0))
    view.defer_remainder()

    # The deferral is observed on the underlying tracker, not on the view.
    residual, deferred_time = safe_tracker.deferred_status()
    # Offset 0 was claimed, so only [1, 10) is deferred, with zero delay.
    self.assertEqual(residual, OffsetRange(1, 10))
    self.assertEqual(deferred_time, timestamp.Duration())

  def test_non_expose_apis(self):
    view = RestrictionTrackerView(
        ThreadsafeRestrictionTracker(
            OffsetRestrictionTracker(OffsetRange(0, 10))))
    # These tracker methods must not be reachable through the view; the
    # lookup itself is expected to raise before any call happens.
    for hidden_method in ('check_done', 'current_progress', 'try_split',
                          'deferred_status'):
      with self.assertRaises(AttributeError):
        getattr(view, hidden_method)()
class ThreadsafeWatermarkEstimatorTest(unittest.TestCase):
  """Checks ThreadsafeWatermarkEstimator around a ManualWatermarkEstimator."""
  def _new_estimator(self):
    # Fresh thread-safe wrapper around a manual estimator with no initial
    # watermark (None).
    return ThreadsafeWatermarkEstimator(ManualWatermarkEstimator(None))

  def test_initialization(self):
    # Wrapping None instead of an estimator is rejected.
    with self.assertRaises(ValueError):
      ThreadsafeWatermarkEstimator(None)

  def test_get_estimator_state(self):
    wrapped = self._new_estimator()
    self.assertIsNone(wrapped.get_estimator_state())
    wrapped.set_watermark(timestamp.Timestamp(10))
    # After set_watermark, the estimator state is the watermark just set.
    self.assertEqual(wrapped.get_estimator_state(), timestamp.Timestamp(10))

  def test_track_timestamp(self):
    wrapped = self._new_estimator()
    # Observing an element timestamp leaves the watermark unset ...
    wrapped.observe_timestamp(timestamp.Timestamp(10))
    self.assertIsNone(wrapped.current_watermark())
    # ... whereas an explicit set_watermark() is reflected immediately.
    wrapped.set_watermark(timestamp.Timestamp(20))
    self.assertEqual(wrapped.current_watermark(), timestamp.Timestamp(20))

  # NOTE(review): "exsited" is a typo for "existed"; the name is kept so
  # existing test selectors keep matching.
  def test_non_exsited_attr(self):
    wrapped = self._new_estimator()
    # Unknown attributes must raise AttributeError instead of being accepted.
    with self.assertRaises(AttributeError):
      wrapped.non_existed_call()
# Allow running this module directly; unittest.main() runs the TestCase
# classes defined in this module.
if __name__ == '__main__':
  unittest.main()