-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathstandard_coders_test.py
More file actions
331 lines (284 loc) · 12.8 KB
/
standard_coders_test.py
File metadata and controls
331 lines (284 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for coders that must be consistent across all Beam SDKs.
"""
# pytype: skip-file
import json
import logging
import math
import os.path
import sys
import unittest
from copy import deepcopy
from typing import Dict
from typing import Tuple
import numpy as np
import yaml
from numpy.testing import assert_array_equal
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import schema_pb2
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.window import IntervalWindow
from apache_beam.typehints import schemas
from apache_beam.utils import windowed_value
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import PaneInfo
from apache_beam.utils.windowed_value import PaneInfoTiming
# Path to the cross-SDK coder test spec, resolved relative to this file.
# Each yaml document in it describes one coder plus example encodings that
# every Beam SDK must reproduce byte-for-byte.
STANDARD_CODERS_YAML = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), '../portability/api/standard_coders.yaml'))
def _load_test_cases(test_yaml):
"""Load test data from yaml file and return an iterable of test cases.
See ``standard_coders.yaml`` for more details.
"""
if not os.path.exists(test_yaml):
raise ValueError('Could not find the test spec: %s' % test_yaml)
with open(test_yaml, 'rb') as coder_spec:
for ix, spec in enumerate(
yaml.load_all(coder_spec, Loader=yaml.SafeLoader)):
spec['index'] = ix
name = spec.get('name', spec['coder']['urn'].split(':')[-2])
yield [name, spec]
def parse_float(s):
  """Parse *s* as a float, forcing NaN to a positive sign.

  On Windows, ``float('NaN')`` carries the opposite sign bit from other
  platforms; these tests only need cross-platform consistency, so any NaN
  is normalized via ``abs``.
  """
  value = float(s)
  return abs(value) if math.isnan(value) else value
def value_parser_from_schema(schema):
  """Build a parser turning yaml-loaded dicts into rows of the given schema.

  Args:
    schema: a ``schema_pb2.Schema`` proto describing the row type.

  Returns:
    A callable taking a dict of field name -> raw yaml value and returning
    an instance of the NamedTuple type generated for ``schema``.

  Raises:
    ValueError: (from the returned parser) if the input dict contains keys
      not present in the schema, or (at build time) if the schema uses a
      field type this parser does not support.
  """
  def attribute_parser_from_type(type_):
    parser = nonnull_attribute_parser_from_type(type_)
    if type_.nullable:
      # Nullable fields pass None through untouched.
      return lambda x: None if x is None else parser(x)
    else:
      return parser

  def nonnull_attribute_parser_from_type(type_):
    # TODO: This should be exhaustive
    type_info = type_.WhichOneof("type_info")
    if type_info == "atomic_type":
      if type_.atomic_type == schema_pb2.BYTES:
        # yaml represents bytes values as strings; encode to real bytes.
        return lambda x: x.encode("utf-8")
      else:
        return schemas.ATOMIC_TYPE_TO_PRIMITIVE[type_.atomic_type]
    elif type_info == "array_type":
      element_parser = attribute_parser_from_type(
          type_.array_type.element_type)
      return lambda x: list(map(element_parser, x))
    elif type_info == "map_type":
      key_parser = attribute_parser_from_type(type_.map_type.key_type)
      value_parser = attribute_parser_from_type(type_.map_type.value_type)
      return lambda x: dict(
          (key_parser(k), value_parser(v)) for k, v in x.items())
    elif type_info == "row_type":
      # Nested rows recurse into a full row parser.
      return value_parser_from_schema(type_.row_type.schema)
    elif type_info == "logical_type":
      # In YAML logical types are represented with their representation types.
      to_language_type = schemas.LogicalType.from_runner_api(
          type_.logical_type).to_language_type
      parse_representation = attribute_parser_from_type(
          type_.logical_type.representation)
      return lambda x: to_language_type(parse_representation(x))
    else:
      # Fail fast with a clear message instead of implicitly returning None,
      # which would surface later as an opaque "'NoneType' object is not
      # callable" when the parser is applied.
      raise ValueError('Unsupported type_info: %s' % type_info)

  parsers = [(field.name, attribute_parser_from_type(field.type))
             for field in schema.fields]

  constructor = schemas.named_tuple_from_schema(schema)

  def value_parser(x):
    result = []
    # Copy so that popping consumed keys does not mutate the caller's dict.
    x = deepcopy(x)
    for name, parser in parsers:
      value = x.pop(name)
      result.append(None if value is None else parser(value))

    # Any leftover keys are attributes the schema does not define.
    if len(x):
      raise ValueError(
          "Test data contains attributes that don't exist in the schema: {}".
          format(', '.join(x.keys())))

    return constructor(*result)

  return value_parser
class StandardCodersTest(unittest.TestCase):
  """Checks the Python SDK coders against ``standard_coders.yaml``.

  Each yaml document names a coder URN and gives example pairs of
  (expected encoded bytes, json-style value); the test encodes/decodes
  each example and compares against the expected bytes. With ``--fix``,
  mismatched expected encodings are rewritten into the yaml file instead
  of failing (see ``tearDownClass``).
  """

  # Maps a coder URN to a callable that converts the yaml (json-style)
  # representation of a value into the Python object that coder handles.
  # Parsers for composite coders additionally receive one parser per
  # component coder.
  _urn_to_json_value_parser = {
      'beam:coder:bytes:v1': lambda x: x.encode('utf-8'),
      'beam:coder:bool:v1': lambda x: x,
      'beam:coder:string_utf8:v1': lambda x: x,
      'beam:coder:varint:v1': lambda x: x,
      'beam:coder:kv:v1': lambda x, key_parser, value_parser:
          (key_parser(x['key']), value_parser(x['value'])),
      # yaml gives end/span in millis; Timestamp takes micros.
      'beam:coder:interval_window:v1': lambda x: IntervalWindow(
          start=Timestamp(micros=(x['end'] - x['span']) * 1000),
          end=Timestamp(micros=x['end'] * 1000)),
      'beam:coder:iterable:v1': lambda x, parser: list(map(parser, x)),
      'beam:coder:state_backed_iterable:v1': lambda x, parser: list(
          map(parser, x)),
      'beam:coder:global_window:v1': lambda x: window.GlobalWindow(),
      'beam:coder:windowed_value:v1': lambda x, value_parser, window_parser:
          windowed_value.create(
              value_parser(x['value']),
              x['timestamp'] * 1000,
              tuple(window_parser(w) for w in x['windows'])),
      'beam:coder:param_windowed_value:v1':
          lambda x, value_parser, window_parser: windowed_value.create(
              value_parser(x['value']),
              x['timestamp'] * 1000,
              tuple(window_parser(w) for w in x['windows']),
              PaneInfo(
                  x['pane']['is_first'],
                  x['pane']['is_last'],
                  PaneInfoTiming.from_string(x['pane']['timing']),
                  x['pane']['index'],
                  x['pane']['on_time_index'])),
      # Cleared timers carry no timestamps or pane info; set timers do.
      'beam:coder:timer:v1': lambda x, value_parser, window_parser:
          userstate.Timer(
              user_key=value_parser(x['userKey']),
              dynamic_timer_tag=x['dynamicTimerTag'],
              clear_bit=x['clearBit'],
              windows=tuple(window_parser(w) for w in x['windows']),
              fire_timestamp=None,
              hold_timestamp=None,
              paneinfo=None) if x['clearBit'] else userstate.Timer(
              user_key=value_parser(x['userKey']),
              dynamic_timer_tag=x['dynamicTimerTag'],
              clear_bit=x['clearBit'],
              fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
              hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
              windows=tuple(window_parser(w) for w in x['windows']),
              paneinfo=PaneInfo(
                  x['pane']['is_first'],
                  x['pane']['is_last'],
                  PaneInfoTiming.from_string(x['pane']['timing']),
                  x['pane']['index'],
                  x['pane']['on_time_index'])),
      'beam:coder:double:v1': parse_float,
      'beam:coder:sharded_key:v1': lambda x, value_parser: ShardedKey(
          key=value_parser(x['key']),
          shard_id=x['shardId'].encode('utf-8')),
      'beam:coder:custom_window:v1': lambda x, window_parser: window_parser(
          x['window']),
      'beam:coder:nullable:v1': lambda x, value_parser: x.encode('utf-8')
      if x else None
  }

  def test_standard_coders(self):
    """Run every test case found in the standard coders yaml spec."""
    for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
      logging.info('Executing %s test.', name)
      self._run_standard_coder(name, spec)

  def _run_standard_coder(self, name, spec):
    """Encode/decode all examples of one yaml spec and verify round trips."""
    def assert_equal(actual, expected):
      """Handle nan values which self.assertEqual fails on."""
      if (isinstance(actual, float) and isinstance(expected, float) and
          math.isnan(actual) and math.isnan(expected)):
        return
      self.assertEqual(actual, expected)

    coder = self.parse_coder(spec['coder'])
    parse_value = self.json_value_parser(spec['coder'])
    # Unless the spec pins nestedness, exercise both nested and unnested.
    nested_list = [spec['nested']] if 'nested' in spec else [True, False]
    for nested in nested_list:
      for expected_encoded, json_value in spec['examples'].items():
        value = parse_value(json_value)
        # yaml stores the raw encoded bytes as a latin1 string.
        expected_encoded = expected_encoded.encode('latin1')
        if not spec['coder'].get('non_deterministic', False):
          actual_encoded = encode_nested(coder, value, nested)
          if self.fix and actual_encoded != expected_encoded:
            # In --fix mode, record the correction instead of failing.
            self.to_fix[spec['index'], expected_encoded] = actual_encoded
          else:
            self.assertEqual(expected_encoded, actual_encoded)
            decoded = decode_nested(coder, expected_encoded, nested)
            assert_equal(decoded, value)
        else:
          # Only verify decoding for a non-deterministic coder
          self.assertEqual(
              decode_nested(coder, expected_encoded, nested), value)
    if spec['coder']['urn'] == 'beam:coder:row:v1':
      # Test batch encoding/decoding as well.
      values = [
          parse_value(json_value) for json_value in spec['examples'].values()
      ]
      # Transpose row-major examples into one numpy array per column.
      columnar = {
          field.name: np.array(
              [getattr(value, field.name) for value in values])
          for field in coder.schema.fields
      }
      dest = {
          field: np.empty_like(values)
          for field, values in columnar.items()
      }
      for column in dest.values():
        # Pre-fill destination columns so decode results are detectable.
        column[:] = 0 if 'int' in column.dtype.name else None
      expected_encoded = ''.join(spec['examples'].keys()).encode('latin1')
      actual_encoded = encode_batch(coder, columnar)
      assert_equal(expected_encoded, actual_encoded)
      decoded_count = decode_batch(coder, expected_encoded, dest)
      assert_equal(len(spec['examples']), decoded_count)
      for field, values in dest.items():
        assert_array_equal(columnar[field], dest[field])

  def parse_coder(self, spec):
    """Recursively build a coder from its yaml spec (urn/payload/components).

    Registers the coder proto in a fresh pipeline context and resolves it
    back into a concrete coder object.
    """
    context = pipeline_context.PipelineContext()
    # Derive a stable-enough unique id for this spec within the context.
    coder_id = str(hash(str(spec)))
    component_ids = [
        context.coders.get_id(self.parse_coder(c))
        for c in spec.get('components', ())
    ]
    if spec.get('state'):

      def iterable_state_read(state_token, elem_coder):
        # Serve state-backed iterable reads from the yaml-provided state map.
        state = spec.get('state').get(state_token.decode('latin1'))
        if state is None:
          state = ''
        input_stream = coder_impl.create_InputStream(state.encode('latin1'))
        while input_stream.size() > 0:
          yield elem_coder.decode_from_stream(input_stream, True)

      context.iterable_state_read = iterable_state_read

    context.coders.put_proto(
        coder_id,
        beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
            component_coder_ids=component_ids))
    return context.coders.get_by_id(coder_id)

  def json_value_parser(self, coder_spec):
    """Return a parser for the yaml representation of this coder's values."""
    # TODO: integrate this with the logic for the other parsers
    if coder_spec['urn'] == 'beam:coder:row:v1':
      # Row coders carry their schema in the payload.
      schema = schema_pb2.Schema.FromString(
          coder_spec['payload'].encode('latin1'))
      return value_parser_from_schema(schema)
    component_parsers = [
        self.json_value_parser(c) for c in coder_spec.get('components', ())
    ]
    return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
        x, *component_parsers)

  # Used when --fix is passed.
  fix = False
  # Maps (yaml doc index, expected encoded bytes) -> actual encoded bytes.
  to_fix: Dict[Tuple[int, bytes], bytes] = {}

  @classmethod
  def tearDownClass(cls):
    """In --fix mode, rewrite mismatched encodings back into the yaml file."""
    if cls.fix and cls.to_fix:
      print("FIXING", len(cls.to_fix), "TESTS")
      doc_sep = '\n---\n'
      docs = open(STANDARD_CODERS_YAML).read().split(doc_sep)

      def quote(s):
        # Encode bytes the way the yaml spec stores them (json latin1
        # string, with NUL shortened to \0).
        return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

      for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
        print(quote(expected_encoded), "->", quote(actual_encoded))
        docs[doc_ix] = docs[doc_ix].replace(
            quote(expected_encoded) + ':', quote(actual_encoded) + ':')
      open(STANDARD_CODERS_YAML, 'w').write(doc_sep.join(docs))
def encode_nested(coder, value, nested=True):
  """Encode *value* with *coder* and return the produced bytes."""
  stream = coder_impl.create_OutputStream()
  coder.get_impl().encode_to_stream(value, stream, nested)
  return stream.get()
def decode_nested(coder, encoded, nested=True):
  """Decode *encoded* bytes with *coder* and return the value."""
  stream = coder_impl.create_InputStream(encoded)
  return coder.get_impl().decode_from_stream(stream, nested)
def encode_batch(row_coder, values):
  """Encode a columnar batch of rows with *row_coder*; return the bytes."""
  stream = coder_impl.create_OutputStream()
  row_coder.get_impl().encode_batch_to_stream(values, stream)
  return stream.get()
def decode_batch(row_coder, encoded, dest):
  """Decode *encoded* into the columnar *dest*; return the row count."""
  stream = coder_impl.create_InputStream(encoded)
  return row_coder.get_impl().decode_batch_from_stream(dest, stream)
if __name__ == '__main__':
  # Strip our custom --fix flag before unittest parses sys.argv; when
  # present, failing encodings are written back into the yaml spec.
  try:
    sys.argv.remove('--fix')
  except ValueError:
    pass
  else:
    StandardCodersTest.fix = True
  unittest.main()