-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathcreate_test.py
More file actions
168 lines (138 loc) · 6.7 KB
/
create_test.py
File metadata and controls
168 lines (138 loc) · 6.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the Create and _CreateSource classes."""
# pytype: skip-file
import logging
import unittest
from apache_beam import Create
from apache_beam import coders
from apache_beam.coders import FastPrimitivesCoder
from apache_beam.internal import pickler
from apache_beam.io import source_test_utils
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
class CreateTest(unittest.TestCase):
    """Exercises the ``Create`` transform and the source built from it.

    Sources are obtained through ``Create._create_source_from_iterable`` and
    checked with the helpers in ``source_test_utils``.
    """

    def setUp(self):
        # Coder shared by every source constructed in these tests.
        self.coder = FastPrimitivesCoder()

    def test_create_transform(self):
        """Create yields exactly its input for zero, one and many elements."""
        with TestPipeline() as p:
            assert_that(p | 'Empty' >> Create([]), equal_to([]), label='empty')
            assert_that(
                p | 'One' >> Create([None]), equal_to([None]), label='one')
            assert_that(p | Create(list(range(10))), equal_to(list(range(10))))

    def test_create_source_read(self):
        """Reading an unsplit source returns every value it was built from."""
        self.check_read([], self.coder)
        self.check_read([1], self.coder)
        # multiple values.
        self.check_read(list(range(10)), self.coder)

    def check_read(self, values, coder):
        """Build a source from `values` and assert a full read returns them.

        Both sides are sorted before comparison, so the order in which the
        source yields the values is not checked.
        """
        source = Create._create_source_from_iterable(values, coder)
        read_values = source_test_utils.read_from_source(source)
        self.assertEqual(sorted(values), sorted(read_values))

    def test_create_source_read_with_initial_splits(self):
        """Initial splitting is consistent for a range of split counts."""
        self.check_read_with_initial_splits([], self.coder, num_splits=2)
        self.check_read_with_initial_splits([1], self.coder, num_splits=2)
        values = list(range(8))
        # multiple values with a single split.
        self.check_read_with_initial_splits(values, self.coder, num_splits=1)
        # multiple values with a single split with a large desired bundle
        # size (a fractional num_splits makes the bundle size exceed the
        # total size).
        self.check_read_with_initial_splits(values, self.coder, num_splits=0.5)
        # multiple values with many splits.
        self.check_read_with_initial_splits(values, self.coder, num_splits=3)
        # multiple values with uneven sized splits.
        self.check_read_with_initial_splits(values, self.coder, num_splits=4)
        # multiple values with num splits equal to num values.
        self.check_read_with_initial_splits(
            values, self.coder, num_splits=len(values))
        # multiple values with num splits greater than num values.
        self.check_read_with_initial_splits(values, self.coder, num_splits=30)

    def check_read_with_initial_splits(self, values, coder, num_splits):
        """A test that splits the given source into `num_splits` and verifies
        that the data read from original source is equal to the union of the
        data read from the split sources.

        The desired bundle size passed to ``split`` is the source's
        ``_total_size`` floor-divided by `num_splits`, so `num_splits` may be
        fractional but must not be zero.
        """
        source = Create._create_source_from_iterable(values, coder)
        desired_bundle_size = source._total_size // num_splits
        splits = source.split(desired_bundle_size)
        splits_info = [
            (split.source, split.start_position, split.stop_position)
            for split in splits
        ]
        source_test_utils.assert_sources_equal_reference_source(
            (source, None, None), splits_info)

    def test_create_source_read_reentrant(self):
        """The unsplit source passes the reentrant-read check."""
        source = Create._create_source_from_iterable(range(9), self.coder)
        source_test_utils.assert_reentrant_reads_succeed((source, None, None))

    def test_create_source_read_reentrant_with_initial_splits(self):
        """Every initial split passes the reentrant-read check."""
        source = Create._create_source_from_iterable(range(24), self.coder)
        for split in source.split(desired_bundle_size=5):
            source_test_utils.assert_reentrant_reads_succeed(
                (split.source, split.start_position, split.stop_position))

    def test_create_source_dynamic_splitting(self):
        """Exhaustive split-at-fraction checks on small and larger sources."""
        # 2 values
        source = Create._create_source_from_iterable(range(2), self.coder)
        source_test_utils.assert_split_at_fraction_exhaustive(source)
        # Multiple values.
        source = Create._create_source_from_iterable(range(11), self.coder)
        source_test_utils.assert_split_at_fraction_exhaustive(
            source, perform_multi_threaded_test=True)

    def test_create_source_progress(self):
        """The range tracker reports the expected progress during a read."""
        num_values = 10
        source = Create._create_source_from_iterable(
            range(num_values), self.coder)
        splits = list(source.split(desired_bundle_size=100))
        # A unittest assertion, unlike a bare ``assert``, still runs under
        # ``python -O`` and reports both values when it fails.
        self.assertEqual(1, len(splits))
        only_split = splits[0]

        fraction_consumed_report = []
        split_points_report = []
        range_tracker = only_split.source.get_range_tracker(
            only_split.start_position, only_split.stop_position)
        # Sample the tracker once for each record the read yields.
        for _ in only_split.source.read(range_tracker):
            fraction_consumed_report.append(range_tracker.fraction_consumed())
            split_points_report.append(range_tracker.split_points())

        self.assertEqual(
            [float(i) / num_values for i in range(num_values)],
            fraction_consumed_report)
        # At the i-th record (0-based) the tracker is expected to report the
        # pair (i, num_values - i).
        expected_split_points_report = [
            (i, num_values - i) for i in range(num_values)
        ]
        self.assertEqual(expected_split_points_report, split_points_report)

    def test_create_uses_coder_for_pickling(self):
        """Create survives a pickle round trip by using the registered coder.

        `_Unpicklable` raises from its pickling hooks, so the round trip can
        only succeed if the values are not pickled directly.
        """
        # NOTE: this registers on the module-level `coders.registry` and is
        # not undone when the test finishes.
        coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
        create = Create([_Unpicklable(1), _Unpicklable(2), _Unpicklable(3)])
        unpickled_create = pickler.loads(pickler.dumps(create))
        self.assertEqual(
            sorted(create.values, key=lambda v: v.value),
            sorted(unpickled_create.values, key=lambda v: v.value))

        with self.assertRaises(NotImplementedError):
            # As there is no special coder for Union types, this will fall
            # back to FastPrimitivesCoder, which in turn falls back to
            # pickling.
            create_mixed_types = Create([_Unpicklable(1), 2])
            pickler.dumps(create_mixed_types)
class _Unpicklable(object):
def __init__(self, value):
self.value = value
def __eq__(self, other):
return self.value == other.value
def __getstate__(self):
raise NotImplementedError()
def __setstate__(self, state):
raise NotImplementedError()
class _UnpicklableCoder(coders.Coder):
    """Coder that carries an ``_Unpicklable`` as the text of its value.

    Encoding writes ``str(value.value)`` as bytes; decoding parses those
    bytes back with ``int`` and wraps the result in a new ``_Unpicklable``.
    """

    def to_type_hint(self):
        # The type this coder is associated with.
        return _Unpicklable

    def decode(self, encoded):
        text = encoded.decode()
        return _Unpicklable(int(text))

    def encode(self, value):
        text = str(value.value)
        return text.encode()
if __name__ == '__main__':
    # When run as a script, set the root logger to INFO before handing
    # control to the unittest runner.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    unittest.main()