# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Tests for transforms that use the SQL Expansion service.""" # pytype: skip-file import logging import typing import unittest import pytest import apache_beam as beam from apache_beam import coders from apache_beam.options.pipeline_options import StandardOptions from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.sql import SqlTransform SimpleRow = typing.NamedTuple( "SimpleRow", [("id", int), ("str", str), ("flt", float)]) coders.registry.register_coder(SimpleRow, coders.RowCoder) Enrich = typing.NamedTuple("Enrich", [("id", int), ("metadata", str)]) coders.registry.register_coder(Enrich, coders.RowCoder) Shopper = typing.NamedTuple( "Shopper", [("shopper", str), ("cart", typing.Mapping[str, int])]) coders.registry.register_coder(Shopper, coders.RowCoder) class Aribitrary: def __init__(self, obj): self.obj = obj def __eq__(self, other): return self.obj == other.obj UserTypeRow = typing.NamedTuple( "UserTypeRow", [("id", int), ("arb", Aribitrary), ("complex", complex)]) coders.registry.register_coder(UserTypeRow, coders.RowCoder) @pytest.mark.xlang_sql_expansion_service 
@unittest.skipIf(
    TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
    None,
    "Must be run with a runner that supports staging java artifacts.")
class SqlTransformTest(unittest.TestCase):
  """Tests that exercise the cross-language SqlTransform (implemented in java).

  Note this test must be executed with pipeline options that run jobs on a
  local job server. The easiest way to accomplish this is to run the
  `validatesCrossLanguageRunnerPythonUsingSql` gradle target for a particular
  job server, which will start the runner and job server for you. For
  example, `:runners:flink:1.13:job-server:\
validatesCrossLanguageRunnerPythonUsingSql` to test on Flink 1.13.

  Alternatively, you may be able to iterate faster if you run the tests
  directly using a runner like `FlinkRunner`, which can start a local Flink
  cluster and job server for you:
    $ pip install -e './sdks/python[gcp,test]'
    $ pytest apache_beam/transforms/sql_test.py \\
        --test-pipeline-options="--runner=FlinkRunner"
  """
  _multiprocess_can_split_ = True

  def test_generate_data(self):
    # A query with no input PCollection: the SELECT of literals produces a
    # single generated row.
    with TestPipeline() as pipeline:
      result = pipeline | SqlTransform(
          """SELECT CAST(1 AS INT) AS `id`,
                    CAST('foo' AS VARCHAR) AS `str`,
                    CAST(3.14 AS DOUBLE) AS `flt`""")
      assert_that(result, equal_to([(1, "foo", 3.14)]))

  def test_project(self):
    # Projecting a subset of columns from the implicit PCOLLECTION table.
    with TestPipeline() as pipeline:
      rows = [SimpleRow(1, "foo", 3.14)]
      result = (
          pipeline
          | beam.Create(rows)
          | SqlTransform("SELECT `id`, `flt` FROM PCOLLECTION"))
      assert_that(result, equal_to([(1, 3.14)]))

  def test_filter(self):
    # A WHERE clause drops non-matching rows.
    with TestPipeline() as pipeline:
      rows = [SimpleRow(1, "foo", 3.14), SimpleRow(2, "bar", 1.414)]
      result = (
          pipeline
          | beam.Create(rows)
          | SqlTransform("SELECT * FROM PCOLLECTION WHERE `str` = 'bar'"))
      assert_that(result, equal_to([(2, "bar", 1.414)]))

  def test_agg(self):
    # GROUP BY with COUNT/SUM/AVG aggregations per key.
    with TestPipeline() as pipeline:
      rows = [
          SimpleRow(1, "foo", 1.),
          SimpleRow(1, "foo", 2.),
          SimpleRow(1, "foo", 3.),
          SimpleRow(2, "bar", 1.414),
          SimpleRow(2, "bar", 1.414),
          SimpleRow(2, "bar", 1.414),
          SimpleRow(2, "bar", 1.414),
      ]
      result = (
          pipeline
          | beam.Create(rows)
          | SqlTransform(
              """
              SELECT
                `str`,
                COUNT(*) AS `count`,
                SUM(`id`) AS `sum`,
                AVG(`flt`) AS `avg`
              FROM PCOLLECTION GROUP BY `str`"""))
      assert_that(result, equal_to([("foo", 3, 3, 2), ("bar", 4, 8, 1.414)]))

  def test_tagged_join(self):
    # Joining two tagged input PCollections by name inside the query.
    with TestPipeline() as pipeline:
      enrich_pcoll = (
          pipeline
          | "Create enrich" >> beam.Create(
              [Enrich(1, "a"), Enrich(2, "b"), Enrich(26, "z")]))
      simple_pcoll = (
          pipeline
          | "Create simple" >> beam.Create([
              SimpleRow(1, "foo", 3.14),
              SimpleRow(26, "bar", 1.11),
              SimpleRow(1, "baz", 2.34)
          ]))
      result = ({
          'simple': simple_pcoll, 'enrich': enrich_pcoll
      }
                | SqlTransform(
                    """
              SELECT simple.`id` AS `id`, enrich.metadata AS metadata
              FROM simple JOIN enrich ON simple.`id` = enrich.`id`"""))
      assert_that(result, equal_to([(1, "a"), (26, "z"), (1, "a")]))

  def test_row(self):
    # Rows built with beam.Row (schema inferred) are usable as SQL input.
    with TestPipeline() as pipeline:
      result = (
          pipeline
          | beam.Create([1, 2, 10])
          | beam.Map(lambda x: beam.Row(a=x, b=str(x)))
          | SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
      assert_that(result, equal_to([(1, 1), (4, 1), (100, 2)]))

  def test_row_user_type(self):
    # Fields holding user-defined types and complex numbers round-trip
    # through the SQL transform.
    with TestPipeline() as pipeline:
      rows = [
          UserTypeRow(1, Aribitrary(1.0), 1 + 2.5j),
          UserTypeRow(1, Aribitrary("abc"), -1j),
      ]
      result = (
          pipeline
          | beam.Create(rows)
          | SqlTransform("SELECT arb, complex FROM PCOLLECTION")
          | beam.Map(tuple))
      assert_that(
          result,
          equal_to([(Aribitrary(1.0), 1 + 2.5j), (Aribitrary("abc"), -1j)]))

  def test_windowing_before_sql(self):
    # A windowing applied upstream is respected: each FixedWindows(10)
    # window aggregates independently, so each element counts alone.
    with TestPipeline() as pipeline:
      rows = [
          SimpleRow(5, "foo", 1.),
          SimpleRow(15, "bar", 2.),
          SimpleRow(25, "baz", 3.)
      ]
      result = (
          pipeline
          | beam.Create(rows)
          | beam.Map(lambda v: beam.window.TimestampedValue(v, v.id)).
          with_output_types(SimpleRow)
          | beam.WindowInto(
              beam.window.FixedWindows(10)).with_output_types(SimpleRow)
          | SqlTransform("SELECT COUNT(*) as `count` FROM PCOLLECTION"))
      assert_that(result, equal_to([(1, ), (1, ), (1, )]))

  def test_map(self):
    # Map-typed fields (Shopper.cart) survive the SQL round trip.
    with TestPipeline() as pipeline:
      shoppers = [
          Shopper('bob', {'bananas': 6, 'cherries': 3}),
          Shopper('alice', {'apples': 2, 'bananas': 3})
      ]
      result = (
          pipeline
          | beam.Create(shoppers).with_output_types(Shopper)
          | SqlTransform("SELECT * FROM PCOLLECTION WHERE shopper = 'alice'"))
      assert_that(result, equal_to([('alice', {'apples': 2, 'bananas': 3})]))

  def test_sql_ddl_set_option(self):
    with TestPipeline() as pipeline:
      input_data = [
          beam.Row(id=1, value=10),
          beam.Row(id=2, value=20),
          beam.Row(id=3, value=30)
      ]

      # DDL uses SET to modify a session option (tests DDL parsing)
      # Using a known Calcite option like sqlConformance
      ddl_statement = """
        SET sqlConformance = 'LENIENT'
        """

      # Query still operates on the implicit PCOLLECTION
      query_statement = "SELECT * FROM PCOLLECTION WHERE id > 2"

      # Input PCollection is piped directly
      result = (
          pipeline
          | beam.Create(input_data)
          # Pass both the query and the DDL
          | SqlTransform(query=query_statement, ddl=ddl_statement))

      # Verify the output matches the query (unaffected by the SET DDL)
      assert_that(result, equal_to([(3, 30)]))


if __name__ == "__main__":
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()