#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import binascii
import glob
import gzip
import io
import json
import logging
import os
import pickle
import random
import re
import shutil
import tempfile
import unittest
import zlib
from datetime import datetime

import pytz

import apache_beam as beam
from apache_beam import Create
from apache_beam import coders
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.tfrecordio import ReadAllFromTFRecord
from apache_beam.io.tfrecordio import ReadFromTFRecord
from apache_beam.io.tfrecordio import WriteToTFRecord
from apache_beam.io.tfrecordio import _TFRecordSink
from apache_beam.io.tfrecordio import _TFRecordUtil
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.test_utils import TempDir
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.util import LogElements

try:
  import tensorflow.compat.v1 as tf  # pylint: disable=import-error
except ImportError:
  try:
    import tensorflow as tf  # pylint: disable=import-error
  except ImportError:
    tf = None  # pylint: disable=invalid-name
    logging.warning('Tensorflow is not installed, so skipping some tests.')

try:
  import crcmod
except ImportError:
  crcmod = None

# Created by running following code in python:
# >>> import tensorflow as tf
# >>> import base64
# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
# >>> writer.write(b'foo')
# >>> writer.close()
# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
# ...   data = base64.b64encode(f.read())
# ...   print(data)
FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g=='

# Same as above but containing two records [b'foo', b'bar']
FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='


def _write_file(path, base64_records):
  """Decode ``base64_records`` and write the raw bytes to ``path``."""
  record = binascii.a2b_base64(base64_records)
  with open(path, 'wb') as f:
    f.write(record)


def _write_file_deflate(path, base64_records):
  """Decode ``base64_records`` and write them zlib-compressed to ``path``."""
  record = binascii.a2b_base64(base64_records)
  with open(path, 'wb') as f:
    f.write(zlib.compress(record))


def _write_file_gzip(path, base64_records):
  """Decode ``base64_records`` and write them gzip-compressed to ``path``."""
  record = binascii.a2b_base64(base64_records)
  with gzip.GzipFile(path, 'wb') as f:
    f.write(record)


class TestTFRecordUtil(unittest.TestCase):
  """Tests for the record-level helpers exposed by ``_TFRecordUtil``."""
  def setUp(self):
    self.record = binascii.a2b_base64(FOO_RECORD_BASE64)

  def _as_file_handle(self, contents):
    """Return an in-memory binary stream positioned at the start."""
    return io.BytesIO(contents)

  def _increment_value_at_index(self, value, index):
    """Return a copy of ``value`` with the byte at ``index`` incremented.

    Raises:
      ValueError: if the byte at ``index`` is already 255.
    """
    mutated = bytearray(value)
    mutated[index] += 1
    return bytes(mutated)

  def _test_error(self, record, error_text):
    """Assert that reading ``record`` raises ValueError with ``error_text``."""
    with self.assertRaisesRegex(ValueError, re.escape(error_text)):
      _TFRecordUtil.read_record(self._as_file_handle(record))

  def test_masked_crc32c(self):
    """Checks ``_masked_crc32c`` against known values."""
    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c(b'\x00' * 32))
    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c(b'\xff' * 32))
    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo'))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c(b'\x03\x00\x00\x00\x00\x00\x00\x00'))

  @unittest.skipIf(crcmod is None, 'crcmod not installed.')
  def test_masked_crc32c_crcmod(self):
    """Same known values, with the crcmod 'crc-32c' function passed in."""
    crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
    self.assertEqual(
        0xfd7fffa,
        _TFRecordUtil._masked_crc32c(b'\x00' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xf909b029,
        _TFRecordUtil._masked_crc32c(b'\xff' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo', crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c(
            b'\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))

  def test_write_record(self):
    """Writing b'foo' produces exactly the reference record bytes."""
    file_handle = io.BytesIO()
    _TFRecordUtil.write_record(file_handle, b'foo')
    self.assertEqual(self.record, file_handle.getvalue())

  def test_read_record(self):
    """Reading the reference record bytes yields b'foo'."""
    actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
    self.assertEqual(b'foo', actual)

  def test_read_record_invalid_record(self):
    """A 3-byte input is rejected as too short."""
    self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')

  def test_read_record_invalid_length_mask(self):
    """Altering byte 9 of the record triggers the length-mask error."""
    record = self._increment_value_at_index(self.record, 9)
    self._test_error(record, 'Mismatch of length mask')

  def test_read_record_invalid_data_mask(self):
    """Altering byte 16 of the record triggers the data-mask error."""
    record = self._increment_value_at_index(self.record, 16)
    self._test_error(record, 'Mismatch of data mask')

  def test_compatibility_read_write(self):
    """Records written by ``write_record`` round-trip via ``read_record``."""
    for record in [b'', b'blah', b'another blah']:
      file_handle = io.BytesIO()
      _TFRecordUtil.write_record(file_handle, record)
      file_handle.seek(0)
      actual = _TFRecordUtil.read_record(file_handle)
      self.assertEqual(record, actual)


class TestTFRecordSink(unittest.TestCase):
  """Tests writing records directly through ``_TFRecordSink``."""
  def _write_lines(self, sink, path, lines):
    """Open ``path`` via ``sink``, write every item of ``lines``, then close.

    The handle is closed even if a write raises, so a failing test does not
    leave the file open.
    """
    f = sink.open(path)
    try:
      for line in lines:
        sink.write_record(f, line)
    finally:
      sink.close(f)

  def test_write_record_single(self):
    """A single b'foo' record matches the reference file contents."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      record = binascii.a2b_base64(FOO_RECORD_BASE64)
      sink = _TFRecordSink(
          path,
          coder=coders.BytesCoder(),
          file_name_suffix='',
          num_shards=0,
          shard_name_template=None,
          compression_type=CompressionTypes.UNCOMPRESSED)
      self._write_lines(sink, path, [b'foo'])

      with open(path, 'rb') as f:
        self.assertEqual(f.read(), record)

  def test_write_record_multiple(self):
    """Records b'foo' and b'bar' match the two-record reference contents."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
      sink = _TFRecordSink(
          path,
          coder=coders.BytesCoder(),
          file_name_suffix='',
          num_shards=0,
          shard_name_template=None,
          compression_type=CompressionTypes.UNCOMPRESSED)
      self._write_lines(sink, path, [b'foo', b'bar'])

      with open(path, 'rb') as f:
        self.assertEqual(f.read(), record)


@unittest.skipIf(tf is None, 'tensorflow not installed.')
class TestWriteToTFRecord(TestTFRecordSink):
  """Tests ``WriteToTFRecord`` output by reading it back with TensorFlow."""
  def test_write_record_gzip(self):
    """Output written with explicit GZIP is readable by TF as GZIP."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      with TestPipeline() as p:
        input_data = [b'foo', b'bar']
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, compression_type=CompressionTypes.GZIP)

      # The pipeline has run once the ``with`` block above exits.
      file_name = glob.glob(file_path_prefix + '-*')[0]
      actual = list(
          tf.python_io.tf_record_iterator(
              file_name,
              options=tf.python_io.TFRecordOptions(
                  tf.python_io.TFRecordCompressionType.GZIP)))
      self.assertEqual(sorted(actual), sorted(input_data))

  def test_write_record_auto(self):
    """Output written with a '.gz' suffix is readable by TF as GZIP."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      with TestPipeline() as p:
        input_data = [b'foo', b'bar']
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      # The pipeline has run once the ``with`` block above exits.
      file_name = glob.glob(file_path_prefix + '-*.gz')[0]
      actual = list(
          tf.python_io.tf_record_iterator(
              file_name,
              options=tf.python_io.TFRecordOptions(
                  tf.python_io.TFRecordCompressionType.GZIP)))
      self.assertEqual(sorted(actual), sorted(input_data))


class TestReadFromTFRecord(unittest.TestCase):
  """Tests ``ReadFromTFRecord`` against pre-built reference files."""
  def test_process_single(self):
    """Reads the single-record reference file."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo']))

  def test_process_multiple(self):
    """Reads the two-record reference file."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_deflate(self):
    """Reads a zlib-compressed file with compression_type=DEFLATE."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.DEFLATE,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_with_coder(self):
    """Reads a gzip file with compression_type=GZIP and an explicit coder."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.GZIP,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_without_coder(self):
    """Reads a gzip file with compression_type=GZIP and the default coder."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(path, compression_type=CompressionTypes.GZIP))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_auto(self):
    """Reads a '.gz' file with compression_type=AUTO and an explicit coder."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_auto(self):
    """Reads a '.gz' file with compression_type=AUTO and the default coder."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(path, compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))


class TestReadAllFromTFRecord(unittest.TestCase):
  """Tests ``ReadAllFromTFRecord`` fed with paths and glob patterns."""
  def _write_glob(self, temp_dir, suffix, include_empty=False):
    """Create three two-record files in ``temp_dir`` ending in ``suffix``.

    When ``include_empty`` is true, one additional zero-byte file with the
    same suffix is created.
    """
    for _ in range(3):
      path = temp_dir.create_temp_file(suffix)
      _write_file(path, FOO_BAR_RECORD_BASE64)
    if include_empty:
      path = temp_dir.create_temp_file(suffix)
      _write_file(path, '')

  def test_process_single(self):
    """Reads one single-record file given by explicit path."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo']))

  def test_process_multiple(self):
    """Reads one two-record file given by explicit path."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_with_filename(self):
    """With ``with_filename=True`` each record is paired with its file path."""
    with TempDir() as temp_dir:
      num_files = 3
      files = []
      for i in range(num_files):
        path = temp_dir.create_temp_file('result%s' % i)
        _write_file(path, FOO_BAR_RECORD_BASE64)
        files.append(path)
      content = [b'foo', b'bar']
      expected = [(file, line) for file in files for line in content]
      pattern = temp_dir.get_path() + os.path.sep + '*'

      with TestPipeline() as p:
        result = (
            p
            | Create([pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                with_filename=True))
        assert_that(result, equal_to(expected))

  def test_process_glob(self):
    """Reads three files matched by a single glob pattern."""
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result')
      # Named ``glob_pattern`` so the imported ``glob`` module is not shadowed.
      glob_pattern = temp_dir.get_path() + os.path.sep + '*result'
      with TestPipeline() as p:
        result = (
            p
            | Create([glob_pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar'] * 3))

  def test_process_glob_with_empty_file(self):
    """A zero-byte file matched by the glob contributes no records."""
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result', include_empty=True)
      # Named ``glob_pattern`` so the imported ``glob`` module is not shadowed.
      glob_pattern = temp_dir.get_path() + os.path.sep + '*result'
      with TestPipeline() as p:
        result = (
            p
            | Create([glob_pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar'] * 3))

  def test_process_multiple_globs(self):
    """Reads nine files matched by three separate glob patterns."""
    with TempDir() as temp_dir:
      globs = []
      for i in range(3):
        suffix = 'result' + str(i)
        self._write_glob(temp_dir, suffix)
        globs.append(temp_dir.get_path() + os.path.sep + '*' + suffix)

      with TestPipeline() as p:
        result = (
            p
            | Create(globs)
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar'] * 9))

  def test_process_deflate(self):
    """Reads a zlib-compressed file with compression_type=DEFLATE."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.DEFLATE))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip(self):
    """Reads a gzip file with compression_type=GZIP."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.GZIP))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_auto(self):
    """Reads a '.gz' file with compression_type=AUTO."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))


class TestEnd2EndWriteAndRead(unittest.TestCase):
  """Round-trip tests: write with ``WriteToTFRecord``, read the files back."""
  def create_inputs(self):
    """Return the pickled bytes of a 12x15 list of random floats."""
    input_array = [[random.random() - 0.5 for _ in range(15)]
                   for _ in range(12)]
    return pickle.dumps(input_array)

  def test_end2end(self):
    """Round-trips ten records through sharded files with default options."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      # Generate a TFRecord file.
      with TestPipeline() as p:
        expected_data = [self.create_inputs() for _ in range(0, 10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression(self):
    """Round-trips ten records through sharded files with a '.gz' suffix."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      # Generate a TFRecord file.
      with TestPipeline() as p:
        expected_data = [self.create_inputs() for _ in range(0, 10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression_unsharded(self):
    """Round-trips ten records through one '.gz' file (empty shard template)."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      # Generate a TFRecord file.
      with TestPipeline() as p:
        expected_data = [self.create_inputs() for _ in range(0, 10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix + '.gz', shard_name_template='')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
        assert_that(actual_data, equal_to(expected_data))

  @unittest.skipIf(tf is None, 'tensorflow not installed.')
  def test_end2end_example_proto(self):
    """Round-trips a ``tf.train.Example`` using ``ProtoCoder``."""
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      example = tf.train.Example()
      example.features.feature['int'].int64_list.value.extend(list(range(3)))
      example.features.feature['bytes'].bytes_list.value.extend(
          [b'foo', b'bar'])

      with TestPipeline() as p:
        _ = p | beam.Create([example]) | WriteToTFRecord(
            file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = (
            p
            | ReadFromTFRecord(
                file_path_prefix + '-*',
                coder=beam.coders.ProtoCoder(example.__class__)))
        assert_that(actual_data, equal_to([example]))

  def test_end2end_read_write_read(self):
    """A non-validating read of a not-yet-written pattern must not fail."""
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      with TestPipeline() as p:
        # Initial read to validate the pipeline doesn't fail before the file is
        # created.
        _ = p | ReadFromTFRecord(path + '-*', validate=False)
        expected_data = [self.create_inputs() for _ in range(0, 10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            path, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(path + '-*', validate=True)
        assert_that(actual_data, equal_to(expected_data))


class GenerateEvent(beam.PTransform):
  """Test source that emits a fixed, timestamped ``TestStream`` of dicts."""
  @staticmethod
  def sample_data():
    """Return a new ``GenerateEvent`` transform."""
    return GenerateEvent()

  @staticmethod
  def _event_time(second):
    """Return the POSIX timestamp of 2021-03-01T00:00:<second> UTC."""
    return datetime(2021, 3, 1, 0, 0, second, 0, tzinfo=pytz.UTC).timestamp()

  def expand(self, input):
    """Apply a ``TestStream`` of twenty element batches to ``input``.

    One batch of three ``{'age': ...}`` dicts is added for each second from
    1 through 20. Immediately before the batches at seconds 5, 10, 15 and 20
    the watermark is advanced to that same second. After the last batch the
    watermark is advanced to second 25 and then to infinity.
    """
    elemlist = [{'age': 10}, {'age': 20}, {'age': 30}]
    elem = elemlist

    stream = TestStream()
    for second in range(1, 21):
      event_time = self._event_time(second)
      if second % 5 == 0:
        # The watermark advance precedes the batch carrying the same time.
        stream = stream.advance_watermark_to(event_time)
      stream = stream.add_elements(elements=elem, event_timestamp=event_time)
    stream = stream.advance_watermark_to(
        self._event_time(25)).advance_watermark_to_infinity()
    return input | stream


class WriteStreamingTest(unittest.TestCase):
  """Checks the names of files written by streaming ``WriteToTFRecord``."""

  # Expected name when the window bounds are rendered as numbers, e.g.
  # ouput_WriteToTFRecord-[1614556800.0, 1614556805.0)-00000-of-00002.tfrecord
  # It captures: window_start, window_end, shard_num, total_shards
  _NUMERIC_WINDOW_FILE_PATTERN = (
      r'.*-\[(?P<window_start>[\d\.]+), '
      r'(?P<window_end>[\d\.]+|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.tfrecord$')

  # Expected name when the window bounds are rendered as date-times, e.g.
  # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
  #   00000-of-00002.tfrecord
  # It captures: window_start, window_end, shard_num, total_shards
  _DATETIME_WINDOW_FILE_PATTERN = (
      r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
      r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.tfrecord$')

  def setUp(self):
    super().setUp()
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  def _assert_output_files(self, pattern_string, expected_count):
    """Assert every output file matches the pattern and count the files.

    Args:
      pattern_string: regular expression each output path must match.
      expected_count: number of output files that must exist.
    """
    pattern = re.compile(pattern_string)
    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'):
      self.assertIsNotNone(
          pattern.match(file_name),
          f"File name {file_name} did not match expected pattern.")
      file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        expected_count,
        "expected %d files, but got: %d" % (expected_count, len(file_names)))

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
    """User-applied 60s fixed windows, default shard name template."""
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | 'User windowing' >> beam.transforms.core.WindowInto(
              beam.transforms.window.FixedWindows(60),
              trigger=beam.transforms.trigger.AfterWatermark(),
              accumulation_mode=beam.transforms.trigger.AccumulationMode.
              DISCARDING,
              allowed_lateness=beam.utils.timestamp.Duration(seconds=0))
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      #TFrecordIO
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          num_shards=num_shards,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    # The pipeline has finished here, so the output files can be inspected.
    self._assert_output_files(self._NUMERIC_WINDOW_FILE_PATTERN, num_shards)

  def test_write_streaming_2_shards_custom_shard_name_template(
      self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
    """Custom shard name template with triggering_frequency=60."""
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      #TFrecordIO
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=60,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    # The pipeline has finished here, so the output files can be inspected.
    self._assert_output_files(self._DATETIME_WINDOW_FILE_PATTERN, num_shards)

  def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
      self,
      num_shards=2,
      shard_name_template='-V-SSSSS-of-NNNNN',
      triggering_frequency=5):
    """Custom shard name template with a 5 second triggering frequency."""
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8')))
      #TFrecordIO
      output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord",
          file_name_suffix=".tfrecord",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=triggering_frequency,
      )
      _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements(
          prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO)

    # for 5s window size, the input should be processed by 5 windows with
    # 2 shards per window
    self._assert_output_files(self._DATETIME_WINDOW_FILE_PATTERN, 10)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()