#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Runs the examples from the README.md file."""

import argparse
import logging
import os
import random
import re
import sys
import tempfile
import unittest

import mock
import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_testing
from apache_beam.yaml import yaml_transform
from apache_beam.yaml import yaml_utils


class FakeSql(beam.PTransform):
  """Test stand-in for the Sql transform.

  Rather than executing the query, it parses a simple `SELECT ... FROM`
  statement just far enough to guess an output schema, then emits a single
  constant row with that schema.
  """
  def __init__(self, query):
    self.query = query

  def default_label(self):
    return 'Sql'

  def expand(self, inputs):
    # A single input PCollection plays the role of the implicit PCOLLECTION
    # table of Beam SQL.
    if isinstance(inputs, beam.PCollection):
      inputs = {'PCOLLECTION': inputs}
    # This only handles the most basic of queries, trying to infer the output
    # schema...
    m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE)
    if not m:
      raise ValueError(self.query)

    def guess_name_and_type(expr):
      """Guesses an (output_name, python_type) pair for one select expression.

      Heuristics only: `AS` aliases, dotted column references, and a few
      well-known aggregate functions are recognized.
      """
      expr = expr.strip().replace('`', '')
      if expr.endswith('*'):
        return 'unknown', str
      parts = expr.split()
      if len(parts) >= 2 and parts[-2].lower() == 'as':
        # Explicit alias: `expr AS name`.
        name = parts[-1]
      elif re.match(r'[\w.]+', parts[0]):
        # Plain (possibly table-qualified) column reference.
        name = parts[0].split('.')[-1]
      else:
        name = f'expr{hash(expr)}'
      if '(' in expr:
        # Looks like a function call; guess the type from the function name.
        expr = expr.lower()
        if expr.startswith('count'):
          typ = int
        elif expr.startswith('avg'):
          typ = float
        else:
          typ = str
      elif '+' in expr:
        # Arithmetic: assume floating point.
        typ = float
      else:
        # Otherwise look the column up in the input schema(s).
        part = parts[0]
        if '.' in part:
          table, field = part.split('.')
          typ = inputs[table].element_type.get_type_for(field)
        else:
          typ = next(iter(inputs.values())).element_type.get_type_for(name)
      # Handle optionals more gracefully.
      if (str(typ).startswith('typing.Union[') or
          str(typ).startswith('typing.Optional[') or '|' in str(typ)):
        if len(typ.__args__) == 2 and type(None) in typ.__args__:
          # Optional[T] -> T
          typ, = [t for t in typ.__args__ if t is not type(None)]
      return name, typ

    if m.group(1) == '*':
      # SELECT *: pass the (first) input through unchanged, schema and all.
      return next(iter(inputs.values())) | beam.Filter(lambda _: True)
    else:
      output_schema = [
          guess_name_and_type(expr) for expr in m.group(1).split(',')
      ]
      # Emit one constant row whose fields carry the guessed types.
      output_element = beam.Row(**{name: typ() for name, typ in output_schema})
      return next(iter(inputs.values())) | beam.Map(
          lambda _: output_element).with_output_types(
              trivial_inference.instance_to_type(output_element))


class FakeReadFromPubSub(beam.PTransform):
  """Test stand-in for ReadFromPubSub producing one fixed, timestamped row."""
  def __init__(self, topic, format, schema):
    pass

  def expand(self, p):
    data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)])
    result = data | beam.Map(
        lambda row: beam.transforms.window.TimestampedValue(row, 0))
    # TODO(robertwb): Allow this to be inferred.
    result.element_type = data.element_type
    return result


class FakeWriteToPubSub(beam.PTransform):
  """Test stand-in for WriteToPubSub; discards nothing, writes nothing."""
  def __init__(self, topic, format):
    pass

  def expand(self, pcoll):
    return pcoll


class FakeAggregation(beam.PTransform):
  """Test stand-in for grouping transforms: counts everything under one key."""
  def __init__(self, **unused_kwargs):
    pass

  def expand(self, pcoll):
    return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field(
        lambda _: 1, sum, 'count')


class _Fakes:
  """Namespace of fakes substituted for `pkg.module.*` references in docs."""
  fn = str

  class SomeTransform(beam.PTransform):
    """Accepts any constructor arguments and passes its input through."""
    def __init__(*args, **kwargs):
      pass

    def expand(self, pcoll):
      return pcoll


# Set from --render_dir in __main__; when not None, pipelines are rendered
# to .png files instead of being executed directly.
RENDER_DIR = None
# Transform-name -> fake-implementation mapping used to build the TestProvider.
TEST_TRANSFORMS = {
    'Sql': FakeSql,
    'ReadFromPubSub': FakeReadFromPubSub,
    'WriteToPubSub': FakeWriteToPubSub,
    'SomeGroupingTransform': FakeAggregation,
    'SomeTransform': _Fakes.SomeTransform,
    'AnotherTransform': _Fakes.SomeTransform,
}


class TestProvider(yaml_provider.InlineProvider):
  """InlineProvider that always wins provider selection."""
  def _affinity(self, other):
    # Always try to choose this one.
    return float('inf')


class TestEnvironment:
  """Context manager supplying temp input/output files for doc snippets."""
  def __enter__(self):
    self.tempdir = tempfile.TemporaryDirectory()
    return self

  def input_file(self, name, content):
    """Writes `content` to a file named `name` in the tempdir; returns path."""
    path = os.path.join(self.tempdir.name, name)
    with open(path, 'w') as fout:
      fout.write(content)
    return path

  def input_csv(self):
    return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')

  def input_tsv(self):
    return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')

  def input_json(self):
    # NOTE(review): the stray '"' after 2.5 makes this line invalid JSON —
    # confirm whether this is intentional or a typo.
    return self.input_file(
        'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5"}\n')

  def output_file(self):
    # Randomized name so multiple writes in one test don't collide.
    return os.path.join(
        self.tempdir.name, str(random.randint(0, 1000)) + '.out')

  def udf_file(self, name):
    """Creates a udf.py defining either my_mapping or my_filter."""
    if name == 'my_mapping':
      lines = '\n'.join(['def my_mapping(row):', '\treturn "good"'])
    else:
      lines = '\n'.join(['def my_filter(row):', '\treturn True'])
    return self.input_file('udf.py', lines)

  def __exit__(self, *args):
    self.tempdir.cleanup()


def replace_recursive(spec, transform_type, arg_name, arg_value):
  """Recursively sets config[arg_name] = arg_value on every transform of
  type `transform_type` anywhere within the (dict/list) spec tree."""
  if isinstance(spec, dict):
    spec = {
        key: replace_recursive(value, transform_type, arg_name, arg_value)
        for (key, value) in spec.items()
    }
    if spec.get('type', None) == transform_type:
      spec['config'][arg_name] = arg_value
    return spec
  elif isinstance(spec, list):
    return [
        replace_recursive(value, transform_type, arg_name, arg_value)
        for value in spec
    ]
  else:
    return spec


def create_test_method(test_type, test_name, test_yaml):
  """Returns a test method that parses/builds/runs the given pipeline yaml.

  Args:
    test_type: 'PARSE' (yaml-load only), 'BUILD' (construct the pipeline),
      or 'RUN' (construct and execute it).
    test_name: used to name rendered output when RENDER_DIR is set.
    test_yaml: the pipeline yaml snippet extracted from the docs.
  """
  # Point doc-style package references at our fakes.
  test_yaml = test_yaml.replace(
      'apache_beam.pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')
  test_yaml = test_yaml.replace(
      'pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')

  def test(self):
    with TestEnvironment() as env:
      nonlocal test_yaml
      # Substitute the placeholder paths used in the docs with real temp
      # files created by the environment.
      test_yaml = test_yaml.replace('/path/to/*.tsv', env.input_tsv())
      if 'MapToFields' in test_yaml or 'Filter' in test_yaml:
        if 'my_mapping' in test_yaml:
          test_yaml = test_yaml.replace(
              '/path/to/some/udf.py', env.udf_file('my_mapping'))
        elif 'my_filter' in test_yaml:
          test_yaml = test_yaml.replace(
              '/path/to/some/udf.py', env.udf_file('my_filter'))
      spec = yaml.load(test_yaml, Loader=SafeLoader)
      if test_type == 'PARSE':
        # Successfully parsing is all that's required.
        return
      if 'ReadFromCsv' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv())
      if 'ReadFromText' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromText', 'path', env.input_csv())
      if 'ReadFromJson' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromJson', 'path', env.input_json())
      for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']:
        if write in test_yaml:
          spec = replace_recursive(spec, write, 'path', env.output_file())
      modified_yaml = yaml.dump(spec)
      options = {'pickle_library': 'cloudpickle'}
      if RENDER_DIR is not None:
        options['runner'] = 'apache_beam.runners.render.RenderRunner'
        options['render_output'] = [
            os.path.join(RENDER_DIR, test_name + '.png')
        ]
        options['render_leaf_composite_nodes'] = ['.*']
      test_provider = TestProvider(TEST_TRANSFORMS)
      # Force SQL through the fake provider above.
      with mock.patch(
          'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider',
          lambda self: test_provider):
        # TODO(polber) - remove once there is support for ExternalTransforms
        # in precommits
        with mock.patch(
            'apache_beam.yaml.yaml_provider.ExternalProvider.create_transform',
            lambda *args, **kwargs: _Fakes.SomeTransform(*args, **kwargs)):
          # Uses the FnApiRunner to ensure errors are mocked/passed through
          # correctly
          p = beam.Pipeline('FnApiRunner', options=PipelineOptions(**options))
          yaml_transform.expand_pipeline(
              p, modified_yaml, yaml_provider.merge_providers([test_provider]))
          if test_type == 'BUILD':
            return
          p.run().wait_until_finish()

  return test


def parse_test_methods(markdown_lines):
  """Yields (name, test_method) pairs for each ``` code block in the doc.

  Fragments that start with '- type:' are wrapped into a complete pipeline
  (with synthesized ReadFromCsv sources for any undefined inputs); blocks
  containing 'tests:' additionally yield yaml_testing-driven test methods
  against the most recent pipeline.
  """
  # pylint: disable=too-many-nested-blocks
  def extract_inputs(input_spec):
    """Returns the set of transform names referenced by an `input` spec."""
    if not input_spec:
      return set()
    elif isinstance(input_spec, str):
      # Strip any trailing '.output_name' qualifier.
      return set([input_spec.split('.')[0]])
    elif isinstance(input_spec, list):
      return set.union(*[extract_inputs(v) for v in input_spec])
    elif isinstance(input_spec, dict):
      return set.union(*[extract_inputs(v) for v in input_spec.values()])
    else:
      raise ValueError("Misformed inputs: " + input_spec)

  def extract_name(input_spec):
    # A transform is referenced by its explicit name, else by its type.
    return input_spec.get('name', input_spec.get('type'))

  code_lines = None  # None outside a code block, accumulating list inside.
  last_pipeline = None  # Most recent full pipeline, for 'tests:' blocks.
  for ix, line in enumerate(markdown_lines):
    line = line.rstrip()
    if line == '```':
      if code_lines is None:
        # Opening fence: start collecting.
        code_lines = []
        test_type = 'RUN'
        test_name = f'test_line_{ix + 2}'
      else:
        # Closing fence: turn what we collected into test(s).
        if code_lines:
          if code_lines[0].startswith('- type:'):
            # This is a fragment (list of transforms), not a full pipeline.
            specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader)
            if 'dependencies:' in specs:
              test_type = 'PARSE'
            is_chain = not any('input' in spec for spec in specs)
            if is_chain:
              undefined_inputs = set(['input'])
            else:
              # Inputs referenced but not defined in this fragment.
              undefined_inputs = set.union(
                  *[extract_inputs(spec.get('input'))
                    for spec in specs]) - set(
                        extract_name(spec) for spec in specs)
            # Treat this as a fragment of a larger pipeline.
            # pylint: disable=not-an-iterable
            code_lines = [
                'pipeline:',
                '  type: chain' if is_chain else '',
                '  transforms:',
            ] + [
                '  - {type: ReadFromCsv, name: "%s", config: {path: x}}' %
                undefined_input for undefined_input in undefined_inputs
            ] + ['    ' + line for line in code_lines]
          if code_lines[0] == 'pipeline:':
            yaml_pipeline = '\n'.join(code_lines)
            last_pipeline = yaml_pipeline
            if 'providers:' in yaml_pipeline or 'tests:' in yaml_pipeline:
              # Don't try to actually run/build these.
              test_type = 'PARSE'
            yield test_name, create_test_method(
                test_type, test_name, yaml_pipeline)
          # Note: exact-line membership test against the collected lines.
          if 'tests:' in code_lines:
            test_spec = '\n'.join(code_lines)
            if code_lines[0] == 'pipeline:':
              yaml_pipeline = '\n'.join(code_lines)
            else:
              # A bare tests: block applies to the previous pipeline.
              yaml_pipeline = last_pipeline
            for sub_ix, test_spec in enumerate(yaml.load(
                '\n'.join(code_lines),
                Loader=yaml_utils.SafeLineLoader)['tests']):
              suffix = test_spec.get('name', str(sub_ix))
              yield (
                  test_name + '_' + suffix,
                  # The yp=... ts=... is to capture the looped closure values.
                  lambda _, yp=yaml_pipeline, ts=test_spec: yaml_testing.
                  run_test(yp, ts))
        code_lines = None
    elif code_lines is not None:
      code_lines.append(line)


def createTestSuite(name, path):
  """Builds a TestCase subclass from the code blocks of a markdown file."""
  with open(path) as readme:
    return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))


# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__), 'docs')


def _doc_suite(suite_name, markdown_file):
  """Builds the TestCase subclass for one markdown doc under YAML_DOCS_DIR."""
  return createTestSuite(
      suite_name, os.path.join(YAML_DOCS_DIR, markdown_file))


ReadMeTest = _doc_suite('ReadMeTest', 'yaml.md')
ErrorHandlingTest = _doc_suite('ErrorHandlingTest', 'yaml-errors.md')
MappingTest = _doc_suite('MappingTest', 'yaml-udf.md')
CombineTest = _doc_suite('CombineTest', 'yaml-combine.md')
InlinePythonTest = _doc_suite('InlinePythonTest', 'yaml-inline-python.md')
JoinTest = _doc_suite('JoinTest', 'yaml-join.md')
TestingTest = _doc_suite('TestingTest', 'yaml-testing.md')

if __name__ == '__main__':
  # --render_dir switches pipeline tests from executing to rendering .pngs;
  # every other argument is handed on to unittest.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--render_dir', default=None)
  parsed_args, remaining_args = arg_parser.parse_known_args(sys.argv)
  if parsed_args.render_dir:
    RENDER_DIR = parsed_args.render_dir
  logging.getLogger().setLevel(logging.INFO)
  unittest.main(argv=remaining_args)