# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # pytype: skip-file import glob import json import logging import os import re import shutil import tempfile import unittest from datetime import datetime from tempfile import TemporaryDirectory import hamcrest as hc import pandas import pytest import pytz from parameterized import param from parameterized import parameterized import apache_beam as beam from apache_beam import Create from apache_beam import Map from apache_beam.io import filebasedsource from apache_beam.io import source_test_utils from apache_beam.io.iobase import RangeTracker from apache_beam.io.parquetio import ReadAllFromParquet from apache_beam.io.parquetio import ReadAllFromParquetBatched from apache_beam.io.parquetio import ReadFromParquet from apache_beam.io.parquetio import ReadFromParquetBatched from apache_beam.io.parquetio import WriteToParquet from apache_beam.io.parquetio import WriteToParquetBatched from apache_beam.io.parquetio import _create_parquet_sink from apache_beam.io.parquetio import _create_parquet_source from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to 
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.util import LogElements

try:
  import pyarrow as pa
  import pyarrow.parquet as pq
  ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
except ImportError:
  pa = None
  pq = None
  ARROW_MAJOR_VERSION = 0


@unittest.skipIf(pa is None, "PyArrow is not installed.")
@pytest.mark.uses_pyarrow
class TestParquet(unittest.TestCase):
  """Tests for batch Parquet sources and sinks (parquetio)."""
  def setUp(self):
    # Reducing the size of thread pools. Without this test execution may fail in
    # environments with limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
    self.temp_dir = tempfile.mkdtemp()

    self.RECORDS = [{
        'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
    }, {
        'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
    }, {
        'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
    }, {
        'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'
    }, {
        'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'
    }, {
        'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'
    }, {
        'name': 'Peter', 'favorite_number': 3, 'favorite_color': None
    }]

    self.SCHEMA = pa.schema([('name', pa.string(), False),
                             ('favorite_number', pa.int64(), False),
                             ('favorite_color', pa.string())])

    # Same records but with a nanosecond timestamp column, used to exercise
    # deprecated INT96 timestamp handling.
    self.SCHEMA96 = pa.schema([('name', pa.string(), False),
                               ('favorite_number', pa.timestamp('ns'), False),
                               ('favorite_color', pa.string())])

    self.RECORDS_NESTED = [{
        'items': [
            {
                'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
            },
            {
                'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
            },
        ]
    }, {
        'items': [
            {
                'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
            },
        ]
    }]

    self.SCHEMA_NESTED = pa.schema([(
        'items',
        pa.list_(
            pa.struct([('name', pa.string(), False),
                       ('favorite_number', pa.int64(), False),
                       ('favorite_color', pa.string())])))])

  def tearDown(self):
    shutil.rmtree(self.temp_dir)

  def _record_to_columns(self, records, schema):
    """Transposes row dicts into per-column value lists, ordered by schema."""
    col_list = []
    for n in schema.names:
      column = []
      for r in records:
        column.append(r[n])
      col_list.append(column)
    return col_list

  def _records_as_arrow(self, schema=None, count=None):
    """Returns self.RECORDS (cycled up to ``count``) as a ``pa.Table``."""
    if schema is None:
      schema = self.SCHEMA
    if count is None:
      count = len(self.RECORDS)
    len_records = len(self.RECORDS)
    data = []
    for i in range(count):
      data.append(self.RECORDS[i % len_records])
    col_data = self._record_to_columns(data, schema)
    col_array = [pa.array(c, schema.types[cn]) for cn, c in enumerate(col_data)]
    return pa.Table.from_arrays(col_array, schema=schema)

  def _write_data(
      self,
      directory=None,
      schema=None,
      prefix=tempfile.template,
      row_group_size=1000,
      codec='none',
      count=None):
    """Writes the sample records to a new Parquet file; returns its path."""
    if directory is None:
      directory = self.temp_dir

    with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                     prefix=prefix) as f:
      table = self._records_as_arrow(schema, count)
      pq.write_table(
          table,
          f,
          row_group_size=row_group_size,
          compression=codec,
          use_deprecated_int96_timestamps=True)

      return f.name

  def _write_pattern(self, num_files, with_filename=False):
    """Writes ``num_files`` Parquet files and returns a glob matching them.

    If ``with_filename`` is True, also returns the list of written paths.
    """
    assert num_files > 0
    temp_dir = tempfile.mkdtemp(dir=self.temp_dir)

    file_list = []
    for _ in range(num_files):
      file_list.append(self._write_data(directory=temp_dir, prefix='mytemp'))

    if with_filename:
      return (temp_dir + os.path.sep + 'mytemp*', file_list)
    return temp_dir + os.path.sep + 'mytemp*'

  def _run_parquet_test(
      self,
      pattern,
      columns,
      desired_bundle_size,
      perform_splitting,
      expected_result):
    """Reads ``pattern`` via a parquet source and checks the read results.

    When ``perform_splitting`` is set, verifies the split sources against the
    unsplit source instead of comparing to ``expected_result`` directly.
    """
    source = _create_parquet_source(pattern, columns=columns)
    if perform_splitting:
      assert desired_bundle_size
      sources_info = [
          (split.source, split.start_position, split.stop_position)
          for split in source.split(desired_bundle_size=desired_bundle_size)
      ]
      if len(sources_info) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)

  def test_read_without_splitting(self):
    file_name = self._write_data()
    expected_result = [self._records_as_arrow()]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_read_with_splitting(self):
    file_name = self._write_data()
    expected_result = [self._records_as_arrow()]
    self._run_parquet_test(file_name, None, 100, True, expected_result)

  def test_source_display_data(self):
    file_name = 'some_parquet_source'
    source = _create_parquet_source(file_name, validate=False)
    dd = DisplayData.create_from(source)

    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_display_data(self):
    file_name = 'some_parquet_source'
    read = ReadFromParquet(file_name, validate=False)
    read_batched = ReadFromParquetBatched(file_name, validate=False)

    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]

    hc.assert_that(
        DisplayData.create_from(read).items,
        hc.contains_inanyorder(*expected_items))
    hc.assert_that(
        DisplayData.create_from(read_batched).items,
        hc.contains_inanyorder(*expected_items))

  def test_sink_display_data(self):
    file_name = 'some_parquet_sink'
    sink = _create_parquet_sink(
        file_name,
        self.SCHEMA,
        'none',
        False,
        False,
        '.end',
        0,
        None,
        'application/x-parquet')
    dd = DisplayData.create_from(sink)

    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_display_data(self):
    file_name = 'some_parquet_sink'
    write = WriteToParquet(file_name, self.SCHEMA)
    dd = DisplayData.create_from(write)

    expected_items = [
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher('row_group_buffer_size', str(64 * 1024 * 1024)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_batched_display_data(self):
    file_name = 'some_parquet_sink'
    write = WriteToParquetBatched(file_name, self.SCHEMA)
    dd = DisplayData.create_from(write)

    expected_items = [
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  @unittest.skipIf(
      ARROW_MAJOR_VERSION >= 13,
      'pyarrow 13.x and above does not throw ArrowInvalid error')
  def test_sink_transform_int96(self):
    with self.assertRaisesRegex(Exception, 'would lose data'):
      # Should throw an error "ArrowInvalid: Casting from timestamp[ns] to
      # timestamp[us] would lose data"
      dst = tempfile.NamedTemporaryFile()
      path = dst.name
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS) \
            | WriteToParquet(
                path, self.SCHEMA96, num_shards=1, shard_name_template='')

  def test_sink_transform(self):
    with TemporaryDirectory() as tmp_dirname:
      # NOTE: was os.path.join(tmp_dirname + "tmp_filename"), which wrote the
      # file *next to* the temp dir and leaked it past cleanup.
      path = os.path.join(tmp_dirname, 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS) \
            | WriteToParquet(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_batched(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create([self._records_as_arrow()]) \
            | WriteToParquetBatched(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_compliant_nested_type(self):
    if ARROW_MAJOR_VERSION < 4:
      return unittest.skip(
          'Writing with compliant nested type is only '
          'supported in pyarrow 4.x and above')
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS_NESTED) \
            | WriteToParquet(
                path,
                self.SCHEMA_NESTED,
                num_shards=1,
                shard_name_template='',
                use_compliant_nested_type=True)
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(
            readback, equal_to([json.dumps(r) for r in self.RECORDS_NESTED]))

  def test_schema_read_write(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      rows = [beam.Row(a=1, b='x'), beam.Row(a=2, b='y')]
      stable_repr = lambda row: json.dumps(row._asdict())
      with TestPipeline() as p:
        _ = p | Create(rows) | WriteToParquet(path)
      with TestPipeline() as p:
        readback = (
            p
            | ReadFromParquet(path + '*', as_rows=True)
            | Map(stable_repr))
        assert_that(readback, equal_to([stable_repr(r) for r in rows]))

  def test_write_with_nullable_fields_missing_data(self):
    """Test WriteToParquet with nullable fields where some fields are missing.

    This test addresses the bug reported in:
    https://github.com/apache/beam/issues/35791
    where WriteToParquet fails with a KeyError if any nullable field
    is missing in the data.
    """
    # Define PyArrow schema with all fields nullable
    schema = pa.schema([
        pa.field("id", pa.int64(), nullable=True),
        pa.field("name", pa.string(), nullable=True),
        pa.field("age", pa.int64(), nullable=True),
        pa.field("email", pa.string(), nullable=True),
    ])

    # Sample data with missing nullable fields
    data = [
        {
            'id': 1, 'name': 'Alice', 'age': 30
        },  # missing 'email'
        {
            'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
        },  # all fields present
        {
            'id': 3, 'name': 'Charlie', 'age': None, 'email': None
        },  # explicit None values
        {
            'id': 4, 'name': 'David'
        },  # missing 'age' and 'email'
    ]

    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'nullable_test')

      # Write data with missing nullable fields - this should not raise KeyError
      with TestPipeline() as p:
        _ = (
            p
            | Create(data)
            | WriteToParquet(
                path, schema, num_shards=1, shard_name_template=''))

      # Read back and verify the data
      with TestPipeline() as p:
        readback = (
            p
            | ReadFromParquet(path + '*')
            | Map(json.dumps, sort_keys=True))

        # Expected data should have None for missing nullable fields
        expected_data = [
            {
                'id': 1, 'name': 'Alice', 'age': 30, 'email': None
            },
            {
                'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
            },
            {
                'id': 3, 'name': 'Charlie', 'age': None, 'email': None
            },
            {
                'id': 4, 'name': 'David', 'age': None, 'email': None
            },
        ]
        assert_that(
            readback,
            equal_to([json.dumps(r, sort_keys=True) for r in expected_data]))

  def test_batched_read(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS, reshuffle=False) \
            | WriteToParquet(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquetBatched(path)
        assert_that(readback, equal_to([self._records_as_arrow()]))

  @parameterized.expand([
      param(compression_type='snappy'),
      param(compression_type='gzip'),
      param(compression_type='brotli'),
      param(compression_type='lz4'),
      param(compression_type='zstd')
  ])
  def test_sink_transform_compressed(self, compression_type):
    if compression_type == 'lz4' and ARROW_MAJOR_VERSION == 1:
      return unittest.skip(
          "Writing with LZ4 compression is not supported in "
          "pyarrow 1.x")
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS) \
            | WriteToParquet(
                path,
                self.SCHEMA,
                codec=compression_type,
                num_shards=1,
                shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path + '*') \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_read_reentrant(self):
    file_name = self._write_data(count=6, row_group_size=3)
    source = _create_parquet_source(file_name)
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_read_without_splitting_multiple_row_group(self):
    file_name = self._write_data(count=12000, row_group_size=1000)
    # We expect 12000 elements, split into batches of 1000 elements. Create
    # a list of pa.Table instances to model this expecation
    expected_result = [
        pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
            count=12000).to_batches(max_chunksize=1000)
    ]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_read_with_splitting_multiple_row_group(self):
    file_name = self._write_data(count=12000, row_group_size=1000)
    # We expect 12000 elements, split into batches of 1000 elements. Create
    # a list of pa.Table instances to model this expecation
    expected_result = [
        pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
            count=12000).to_batches(max_chunksize=1000)
    ]
    self._run_parquet_test(file_name, None, 10000, True, expected_result)

  def test_dynamic_work_rebalancing(self):
    # This test depends on count being sufficiently large + the ratio of
    # count to row_group_size also being sufficiently large (but the required
    # ratio to pass varies for values of row_group_size and, somehow, the
    # version of pyarrow being tested against.)
    file_name = self._write_data(count=320, row_group_size=20)
    source = _create_parquet_source(file_name)

    splits = [split for split in source.split(desired_bundle_size=float('inf'))]
    assert len(splits) == 1

    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source, splits[0].start_position, splits[0].stop_position)

  def test_min_bundle_size(self):
    file_name = self._write_data(count=120, row_group_size=20)

    source = _create_parquet_source(
        file_name, min_bundle_size=100 * 1024 * 1024)
    splits = [split for split in source.split(desired_bundle_size=1)]
    self.assertEqual(len(splits), 1)

    source = _create_parquet_source(file_name, min_bundle_size=0)
    splits = [split for split in source.split(desired_bundle_size=1)]
    self.assertNotEqual(len(splits), 1)

  def _convert_to_timestamped_record(self, record):
    """Returns a copy of ``record`` with favorite_number as a pd.Timestamp."""
    timestamped_record = record.copy()
    timestamped_record['favorite_number'] =\
      pandas.Timestamp(timestamped_record['favorite_number'])
    return timestamped_record

  def test_int96_type_conversion(self):
    file_name = self._write_data(
        count=120, row_group_size=20, schema=self.SCHEMA96)
    orig = self._records_as_arrow(count=120, schema=self.SCHEMA96)
    expected_result = [
        pa.Table.from_batches([batch], schema=self.SCHEMA96)
        for batch in orig.to_batches(max_chunksize=20)
    ]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_split_points(self):
    file_name = self._write_data(count=12000, row_group_size=3000)
    source = _create_parquet_source(file_name)

    splits = [split for split in source.split(desired_bundle_size=float('inf'))]
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())

    # There are a total of four row groups. Each row group has 3000 records.

    # When reading records of the first group, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report,
        [
            (0, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (1, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (2, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (3, 1),
        ])

  def test_selective_columns(self):
    file_name = self._write_data()
    orig = self._records_as_arrow()
    name_column = self.SCHEMA.field('name')
    expected_result = [
        pa.Table.from_arrays(
            [orig.column('name')],
            schema=pa.schema([('name', name_column.type, name_column.nullable)
                              ]))
    ]
    self._run_parquet_test(file_name, ['name'], None, False, expected_result)

  def test_sink_transform_multiple_row_group(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, 'tmp_filename')
      # Pin to FnApiRunner since test assumes fixed bundle size
      with TestPipeline('FnApiRunner') as p:
        # writing 623200 bytes of data
        _ = p \
            | Create(self.RECORDS * 4000) \
            | WriteToParquet(
                path,
                self.SCHEMA,
                num_shards=1,
                codec='none',
                shard_name_template='',
                row_group_buffer_size=250000)
      self.assertEqual(pq.read_metadata(path).num_row_groups, 3)

  def test_read_all_from_parquet_single_file(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS))

    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()]))

  def test_read_all_from_parquet_many_single_files(self):
    path1 = self._write_data()
    path2 = self._write_data()
    path3 = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 3))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 3))

  def test_read_all_from_parquet_file_pattern(self):
    file_pattern = self._write_pattern(5)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 5))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 5))

  def test_read_all_from_parquet_many_file_patterns(self):
    file_pattern1 = self._write_pattern(5)
    file_pattern2 = self._write_pattern(2)
    file_pattern3 = self._write_pattern(3)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 10))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 10))

  def test_read_all_from_parquet_with_filename(self):
    file_pattern, file_paths = self._write_pattern(3, with_filename=True)
    result = [(path, record) for path in file_paths for record in self.RECORDS]
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquet(with_filename=True),
          equal_to(result))


class GenerateEvent(beam.PTransform):
  """Emits a fixed stream of timestamped events via TestStream.

  Three {'age': ...} elements are added at each event-time second from
  t=1s to t=20s (2021-03-01 00:00 UTC base), with watermark advances at
  5s intervals, then the watermark is advanced to infinity.
  """
  @staticmethod
  def sample_data():
    return GenerateEvent()

  def expand(self, pbegin):
    elemlist = [{'age': 10}, {'age': 20}, {'age': 30}]
    elem = elemlist
    return (
        pbegin
        | TestStream().add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 1, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 2, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 3, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 4, 0, tzinfo=pytz.UTC).timestamp()).
        advance_watermark_to(
            datetime(2021, 3, 1, 0, 0, 5, 0,
                     tzinfo=pytz.UTC).timestamp()).add_elements(
                         elements=elem,
                         event_timestamp=datetime(
                             2021, 3, 1, 0, 0, 5, 0,
                             tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 6, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 7, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 8, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 9, 0, tzinfo=pytz.UTC).timestamp()).
        advance_watermark_to(
            datetime(2021, 3, 1, 0, 0, 10, 0,
                     tzinfo=pytz.UTC).timestamp()).add_elements(
                         elements=elem,
                         event_timestamp=datetime(
                             2021, 3, 1, 0, 0, 10, 0,
                             tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 11, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 12, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 13, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 14, 0, tzinfo=pytz.UTC).timestamp()).
        advance_watermark_to(
            datetime(2021, 3, 1, 0, 0, 15, 0,
                     tzinfo=pytz.UTC).timestamp()).add_elements(
                         elements=elem,
                         event_timestamp=datetime(
                             2021, 3, 1, 0, 0, 15, 0,
                             tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 16, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 17, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 18, 0, tzinfo=pytz.UTC).timestamp()).
        add_elements(
            elements=elem,
            event_timestamp=datetime(
                2021, 3, 1, 0, 0, 19, 0, tzinfo=pytz.UTC).timestamp()).
        advance_watermark_to(
            datetime(2021, 3, 1, 0, 0, 20, 0,
                     tzinfo=pytz.UTC).timestamp()).add_elements(
                         elements=elem,
                         event_timestamp=datetime(
                             2021, 3, 1, 0, 0, 20, 0,
                             tzinfo=pytz.UTC).timestamp()).
        advance_watermark_to(
            datetime(2021, 3, 1, 0, 0, 25, 0,
                     tzinfo=pytz.UTC).timestamp()).advance_watermark_to_infinity())


class WriteStreamingTest(unittest.TestCase):
  """Tests for WriteToParquet on unbounded (streaming) input."""
  def setUp(self):
    super().setUp()
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
    with TestPipeline() as p:
      output = (p | GenerateEvent.sample_data())

      #ParquetIO
      pyschema = pa.schema([('age', pa.int64())])
      output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
          file_path_prefix=self.tempdir + "/ouput_WriteToParquet",
          file_name_suffix=".parquet",
          num_shards=num_shards,
          triggering_frequency=60,
          schema=pyschema)
      _ = output2 | 'LogElements after WriteToParquet' >> LogElements(
          prefix='after WriteToParquet ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet
    # It captures: window_interval, shard_num, total_shards
    # NOTE: named groups restored; the pattern had lost its group names,
    # which makes re.compile raise.
    pattern_string = (
        r'.*-\[(?P<window_start>[\d\.]+), '
        r'(?P<window_end>[\d\.]+|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
    pattern = re.compile(pattern_string)

    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        num_shards,
        "expected %d files, but got: %d" % (num_shards, len(file_names)))

  def test_write_streaming_2_shards_custom_shard_name_template(
      self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
    with TestPipeline() as p:
      output = (p | GenerateEvent.sample_data())

      #ParquetIO
      pyschema = pa.schema([('age', pa.int64())])
      output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
          file_path_prefix=self.tempdir + "/ouput_WriteToParquet",
          file_name_suffix=".parquet",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=60,
          schema=pyschema)
      _ = output2 | 'LogElements after WriteToParquet' >> LogElements(
          prefix='after WriteToParquet ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
    # 00000-of-00002.parquet
    # It captures: window_interval, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
        r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
    pattern = re.compile(pattern_string)

    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertEqual(
        len(file_names),
        num_shards,
        "expected %d files, but got: %d" % (num_shards, len(file_names)))

  def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
      self,
      num_shards=2,
      shard_name_template='-V-SSSSS-of-NNNNN',
      triggering_frequency=5):
    with TestPipeline() as p:
      output = (p | GenerateEvent.sample_data())

      #ParquetIO
      pyschema = pa.schema([('age', pa.int64())])
      output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
          file_path_prefix=self.tempdir + "/ouput_WriteToParquet",
          file_name_suffix=".parquet",
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=triggering_frequency,
          schema=pyschema)
      _ = output2 | 'LogElements after WriteToParquet' >> LogElements(
          prefix='after WriteToParquet ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
    # 00000-of-00002.parquet
    # It captures: window_interval, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
        r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
    pattern = re.compile(pattern_string)

    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    # for 5s window size, the input should be processed by 5 windows with
    # 2 shards per window
    expected_file_count = 10
    self.assertEqual(
        len(file_names),
        expected_file_count,
        # NOTE: message previously reported num_shards (2) as the expected
        # count while the assertion checked for 10.
        "expected %d files, but got: %d" %
        (expected_file_count, len(file_names)))

  def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll(  # pylint: disable=line-too-long
      self):
    with TestPipeline() as p:
      output = (
          p
          | GenerateEvent.sample_data()
          | 'User windowing' >> beam.transforms.core.WindowInto(
              beam.transforms.window.FixedWindows(10),
              trigger=beam.transforms.trigger.AfterWatermark(),
              accumulation_mode=beam.transforms.trigger.AccumulationMode.
              DISCARDING,
              allowed_lateness=beam.utils.timestamp.Duration(seconds=0)))

      #ParquetIO
      pyschema = pa.schema([('age', pa.int64())])
      output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
          file_path_prefix=self.tempdir + "/ouput_WriteToParquet",
          file_name_suffix=".parquet",
          num_shards=0,
          schema=pyschema)
      _ = output2 | 'LogElements after WriteToParquet' >> LogElements(
          prefix='after WriteToParquet ', with_window=True, level=logging.INFO)

    # Regex to match the expected windowed file pattern
    # Example:
    # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet
    # It captures: window_interval, shard_num, total_shards
    pattern_string = (
        r'.*-\[(?P<window_start>[\d\.]+), '
        r'(?P<window_end>[\d\.]+|Infinity)\)-'
        r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
    pattern = re.compile(pattern_string)

    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'):
      match = pattern.match(file_name)
      self.assertIsNotNone(
          match, f"File name {file_name} did not match expected pattern.")
      if match:
        file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    self.assertGreaterEqual(
        len(file_names),
        1 * 3,  #25s of data covered by 3 10s windows
        "expected %d files, but got: %d" % (1 * 3, len(file_names)))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()