From 904a3bb58eaf852e42615b7e4b0b01de5d576ee5 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Fri, 20 Aug 2021 14:02:55 -0400 Subject: [PATCH 01/10] sort and add ARRAY to type list --- sqlalchemy_bigquery/__init__.py | 42 +++++++++++++++++---------------- sqlalchemy_bigquery/base.py | 42 +++++++++++++++++---------------- 2 files changed, 44 insertions(+), 40 deletions(-) diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py index e3dd3f2d..8b161871 100644 --- a/sqlalchemy_bigquery/__init__.py +++ b/sqlalchemy_bigquery/__init__.py @@ -24,40 +24,42 @@ from .base import BigQueryDialect from .base import ( - STRING, + ARRAY, + BIGNUMERIC, BOOL, BOOLEAN, + BYTES, + DATE, + DATETIME, + FLOAT, + FLOAT64, INT64, INTEGER, - FLOAT64, - FLOAT, - TIMESTAMP, - DATETIME, - DATE, - BYTES, - TIME, - RECORD, NUMERIC, - BIGNUMERIC, + RECORD, + STRING, + TIME, + TIMESTAMP, ) __all__ = [ + "ARRAY", + "BIGNUMERIC", "BigQueryDialect", - "STRING", "BOOL", "BOOLEAN", + "BYTES", + "DATE", + "DATETIME", + "FLOAT", + "FLOAT64", "INT64", "INTEGER", - "FLOAT64", - "FLOAT", - "TIMESTAMP", - "DATETIME", - "DATE", - "BYTES", - "TIME", - "RECORD", "NUMERIC", - "BIGNUMERIC", + "RECORD", + "STRING", + "TIME", + "TIMESTAMP", ] try: diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index a51e4748..7eb6c8ec 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -114,39 +114,41 @@ def format_label(self, label, name=None): _type_map = { - "STRING": types.String, - "BOOL": types.Boolean, + "ARRAY": types.ARRAY, + "BIGNUMERIC": types.Numeric, "BOOLEAN": types.Boolean, - "INT64": types.Integer, - "INTEGER": types.Integer, + "BOOL": types.Boolean, + "BYTES": types.BINARY, + "DATETIME": types.DATETIME, + "DATE": types.DATE, "FLOAT64": types.Float, "FLOAT": types.Float, + "INT64": types.Integer, + "INTEGER": types.Integer, + "NUMERIC": types.Numeric, + "RECORD": types.JSON, + "STRING": types.String, "TIMESTAMP": types.TIMESTAMP, - "DATETIME": types.DATETIME, - "DATE": types.DATE, - "BYTES": types.BINARY, "TIME": types.TIME, - "RECORD": types.JSON, - "NUMERIC": types.Numeric, - "BIGNUMERIC": types.Numeric, } # By convention, dialect-provided types are spelled with all upper case. -STRING = _type_map["STRING"] -BOOL = _type_map["BOOL"] +ARRAY = _type_map["ARRAY"] +BIGNUMERIC = _type_map["NUMERIC"] BOOLEAN = _type_map["BOOLEAN"] -INT64 = _type_map["INT64"] -INTEGER = _type_map["INTEGER"] +BOOL = _type_map["BOOL"] +BYTES = _type_map["BYTES"] +DATETIME = _type_map["DATETIME"] +DATE = _type_map["DATE"] FLOAT64 = _type_map["FLOAT64"] FLOAT = _type_map["FLOAT"] +INT64 = _type_map["INT64"] +INTEGER = _type_map["INTEGER"] +NUMERIC = _type_map["NUMERIC"] +RECORD = _type_map["RECORD"] +STRING = _type_map["STRING"] TIMESTAMP = _type_map["TIMESTAMP"] -DATETIME = _type_map["DATETIME"] -DATE = _type_map["DATE"] -BYTES = _type_map["BYTES"] TIME = _type_map["TIME"] -RECORD = _type_map["RECORD"] -NUMERIC = _type_map["NUMERIC"] -BIGNUMERIC = _type_map["NUMERIC"] try: _type_map["GEOGRAPHY"] = GEOGRAPHY From ffcb3e7f9e9c8947127df1b6a76bad434cfc0eed Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Fri, 20 Aug 2021 14:22:26 -0400 Subject: [PATCH 02/10] update constraint for biqquery --- testing/constraints-3.6.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 1785edd0..e5ed0b2a 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -6,5 +6,5 @@ # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", sqlalchemy==1.2.0 google-auth==1.25.0 -google-cloud-bigquery==2.19.0 +google-cloud-bigquery==2.24.1 google-api-core==1.30.0 From 806aa5ca08ecdda8b5d572cf77c328c0cbed3b36 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Fri, 20 Aug 2021 14:59:26 -0400 Subject: [PATCH 03/10] fix unnest by adding table aliases when needed --- sqlalchemy_bigquery/base.py | 41 +++++++++++++++++++ tests/system/test_sqlalchemy_bigquery.py | 29 ++++++++++++++ tests/unit/test_select.py | 51 ++++++++++++++++++++++++ 3 files changed, 121 insertions(+) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 7eb6c8ec..2657998d 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -62,6 +62,8 @@ FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+") +TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases" + def assert_(cond, message="Assertion failed"): # pragma: NO COVER if not cond: @@ -248,6 +250,38 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw): insert_stmt, asfrom=False, **kw ) + def visit_table_valued_alias(self, element, **kw): + # When using table-valued functions, like UNNEST, BigQuery requires a + # FROM for any table referenced in the function, including expressions + # in function arguments. + # + # This is tricky because: + # 1. We have to find the table references. + # 2. We can't know practically if there's already a FROM for a table. + # + # We leverage visit_column to find a table reference. Whenever we find + # one, we create an alias for it. + # + # This requires communicating between this function and visit_column. + # We do this by sticking a dictionary in the keyword arguments. + # This dictionary: + # a. Tells visit_column that it's an a table-valued alias expresssion, and + # b. Gives it a place to record the aliases it creates. + # + # This function creates aliases in the FROM list for any aliases recorded + # by visit_column. + + kw[TABLE_VALUED_ALIAS_ALIASES] = {} + ret = super().visit_table_valued_alias(element, **kw) + aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES) + if aliases: + aliases = ", ".join( + f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}" + for tablename, alias in aliases.items() + ) + ret = f"{aliases}, {ret}" + return ret + def visit_column( self, column, add_to_result_map=None, include_table=True, **kwargs ): @@ -273,6 +307,13 @@ def visit_column( tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) + elif TABLE_VALUED_ALIAS_ALIASES in kwargs: + aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES] + if tablename not in aliases: + aliases[tablename] = self.anon_map[ + f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}" + ] + tablename = aliases[tablename] return self.preparer.quote(tablename) + "." + name def visit_label(self, *args, within_group_by=False, **kwargs): diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 06024368..f49e5ae5 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -34,6 +34,7 @@ import datetime import decimal +sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) ONE_ROW_CONTENTS_EXPANDED = [ 588, @@ -691,3 +692,31 @@ def test_has_table(engine, engine_using_test_dataset, bigquery_dataset): assert engine_using_test_dataset.has_table(f"{bigquery_dataset}.sample") is True assert engine_using_test_dataset.has_table("sample_alt") is False + + +@pytest.mark.skipif( + sqlalchemy_version_info < (1, 4), + reason="unnest (and other table-valued-function) support required version 1.4", +) +def test_unnest(engine, bigquery_dataset): + from sqlalchemy import select, func, String + from sqlalchemy_bigquery import ARRAY + + conn = engine.connect() + metadata = MetaData() + table = Table( + f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)), + ) + metadata.create_all(engine) + conn.execute( + table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])] + ) + query = select([func.unnest(table.c.objects).alias("foo_objects").column]) + compiled = str(query.compile(engine)) + assert " ".join(compiled.strip().split()) == ( + f"SELECT `foo_objects`" + f" FROM" + f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`," + f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`" + ) + assert sorted(r[0] for r in conn.execute(query)) == ["a", "b", "c", "x", "y"] diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 5d49ae68..4dd67a1c 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -356,3 +356,54 @@ def test_select_notin_param_empty(faux_conn): else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}, ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_unnest(faux_conn, alias): + from sqlalchemy import String + from sqlalchemy_bigquery import ARRAY + + table = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String))) + fcall = sqlalchemy.func.unnest(table.c.objects) + if alias: + query = fcall.alias("foo_objects").column + else: + query = fcall.column_valued("foo_objects") + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `foo_objects` FROM `t` `t_1`, unnest(`t_1`.`objects`) AS `foo_objects`" + ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_table_valued_alias_w_multiple_references_to_the_same_table(faux_conn, alias): + from sqlalchemy import String + from sqlalchemy_bigquery import ARRAY + + table = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String))) + fcall = sqlalchemy.func.foo(table.c.objects, table.c.objects) + if alias: + query = fcall.alias("foo_objects").column + else: + query = fcall.column_valued("foo_objects") + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `foo_objects` " + "FROM `t` `t_1`, foo(`t_1`.`objects`, `t_1`.`objects`) AS `foo_objects`" + ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_unnest_w_no_table_references(faux_conn, alias): + fcall = sqlalchemy.func.unnest([1, 2, 3]) + if alias: + query = fcall.alias().column + else: + query = fcall.column_valued() + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `anon_1` FROM unnest(%(unnest_1)s) AS `anon_1`" + ) From 6a679aebe8d4d6841e9f34de3fe9c5dc64cc6366 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Mon, 23 Aug 2021 14:21:11 -0600 Subject: [PATCH 04/10] fix mis-merge --- tests/unit/test_select.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index e8b99bb3..61084540 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -373,6 +373,7 @@ def nstr(q): assert ( nstr(q.compile(faux_conn.engine, compile_kwargs={"literal_binds": True})) == "SELECT `test`.`val` FROM `test` WHERE `test`.`val` IN (2)" + ) @sqlalchemy_1_4_or_higher From 11ca126fe7aef0815557b7d2031f3e5ee18decd6 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Tue, 24 Aug 2021 09:48:47 -0600 Subject: [PATCH 05/10] Update sqlalchemy_bigquery/base.py Co-authored-by: Tim Swast --- sqlalchemy_bigquery/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 570c2748..98edfb9e 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -277,7 +277,7 @@ def visit_table_valued_alias(self, element, **kw): # 2. We can't know practically if there's already a FROM for a table. # # We leverage visit_column to find a table reference. Whenever we find - # one, we create an alias for it, so as not to conlfict with an existing + # one, we create an alias for it, so as not to conflict with an existing # reference if one is present. # # This requires communicating between this function and visit_column. From 32980babf2fe85ac1317c840b889bc8ef76cfcad Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Tue, 24 Aug 2021 14:17:43 -0600 Subject: [PATCH 06/10] Use packaging.version.parse to parse versions --- setup.py | 1 + .../test_dialect_compliance.py | 5 ++++- tests/system/test_sqlalchemy_bigquery.py | 5 ++--- tests/unit/conftest.py | 12 ++++++++---- tests/unit/test_select.py | 7 +++++-- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index fd8a8acd..4ec8d059 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def readme(): # https://github.com/googleapis/google-cloud-python/issues/10566 "google-auth>=1.25.0,<3.0.0dev", # Work around pip wack. "google-cloud-bigquery>=2.24.1", + "packaging", "sqlalchemy>=1.2.0,<1.5.0dev", "future", ], diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py index c126c4f7..5c708b78 100644 --- a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -19,6 +19,7 @@ import datetime import mock +import packaging.version import pytest import pytz import sqlalchemy @@ -41,7 +42,9 @@ ) -if sqlalchemy.__version__ < "1.4": +if (packaging.version.parse(sqlalchemy.__version__) + < packaging.version.parse("1.4") +): from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest class LimitOffsetTest(_LimitOffsetTest): diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index cb5ba483..667fae54 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -28,14 +28,13 @@ from sqlalchemy.sql import expression, select, literal_column from sqlalchemy.exc import NoSuchTableError from sqlalchemy.orm import sessionmaker +import packaging.version from pytz import timezone import pytest import sqlalchemy import datetime import decimal -sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) - ONE_ROW_CONTENTS_EXPANDED = [ 588, datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")), @@ -729,7 +728,7 @@ class MyTable(Base): @pytest.mark.skipif( - sqlalchemy_version_info < (1, 4), + packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse('1.4'), reason="unnest (and other table-valued-function) support required version 1.4", ) def test_unnest(engine, bigquery_dataset): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e5de882d..53cb5431 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -21,20 +21,24 @@ import mock import sqlite3 +import packaging.version import pytest import sqlalchemy import fauxdbi -sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) +sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__) sqlalchemy_1_3_or_higher = pytest.mark.skipif( - sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher" + sqlalchemy_version < packaging.version.parse('1.3'), + reason="requires sqlalchemy 1.3 or higher" ) sqlalchemy_1_4_or_higher = pytest.mark.skipif( - sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher" + sqlalchemy_version < packaging.version.parse('1.4'), + reason="requires sqlalchemy 1.4 or higher" ) sqlalchemy_before_1_4 = pytest.mark.skipif( - sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower" + sqlalchemy_version >= packaging.version.parse('1.4'), + reason="requires sqlalchemy 1.3 or lower" ) diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 61084540..27bc3f8d 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -20,6 +20,7 @@ import datetime from decimal import Decimal +import packaging.version import pytest import sqlalchemy @@ -292,7 +293,8 @@ def test_select_in_param_empty(faux_conn): assert not isin assert faux_conn.test_data["execute"][-1] == ( "SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`" - if sqlalchemy.__version__ >= "1.4" + if (packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4")) else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}, ) @@ -352,7 +354,8 @@ def test_select_notin_param_empty(faux_conn): assert isnotin assert faux_conn.test_data["execute"][-1] == ( "SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`" - if sqlalchemy.__version__ >= "1.4" + if (packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4")) else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}, ) From 4de3d7107abab906bd5a863cd23c4e7637accf4f Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Tue, 24 Aug 2021 14:19:17 -0600 Subject: [PATCH 07/10] blacken --- .../test_dialect_compliance.py | 4 +--- tests/system/test_sqlalchemy_bigquery.py | 2 +- tests/unit/conftest.py | 12 ++++++------ tests/unit/test_select.py | 12 ++++++++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py index 5c708b78..156e6167 100644 --- a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -42,9 +42,7 @@ ) -if (packaging.version.parse(sqlalchemy.__version__) - < packaging.version.parse("1.4") -): +if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"): from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest class LimitOffsetTest(_LimitOffsetTest): diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 667fae54..63dc220b 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -728,7 +728,7 @@ class MyTable(Base): @pytest.mark.skipif( - packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse('1.4'), + packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), reason="unnest (and other table-valued-function) support required version 1.4", ) def test_unnest(engine, bigquery_dataset): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 53cb5431..886e9aee 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -29,16 +29,16 @@ sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__) sqlalchemy_1_3_or_higher = pytest.mark.skipif( - sqlalchemy_version < packaging.version.parse('1.3'), - reason="requires sqlalchemy 1.3 or higher" + sqlalchemy_version < packaging.version.parse("1.3"), + reason="requires sqlalchemy 1.3 or higher", ) sqlalchemy_1_4_or_higher = pytest.mark.skipif( - sqlalchemy_version < packaging.version.parse('1.4'), - reason="requires sqlalchemy 1.4 or higher" + sqlalchemy_version < packaging.version.parse("1.4"), + reason="requires sqlalchemy 1.4 or higher", ) sqlalchemy_before_1_4 = pytest.mark.skipif( - sqlalchemy_version >= packaging.version.parse('1.4'), - reason="requires sqlalchemy 1.3 or lower" + sqlalchemy_version >= packaging.version.parse("1.4"), + reason="requires sqlalchemy 1.3 or lower", ) diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 27bc3f8d..10669864 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -293,8 +293,10 @@ def test_select_in_param_empty(faux_conn): assert not isin assert faux_conn.test_data["execute"][-1] == ( "SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`" - if (packaging.version.parse(sqlalchemy.__version__) - >= packaging.version.parse("1.4")) + if ( + packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4") + ) else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}, ) @@ -354,8 +356,10 @@ def test_select_notin_param_empty(faux_conn): assert isnotin assert faux_conn.test_data["execute"][-1] == ( "SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`" - if (packaging.version.parse(sqlalchemy.__version__) - >= packaging.version.parse("1.4")) + if ( + packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4") + ) else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}, ) From 5cde0d4c7724bd0e59833a1a4549d55fe0cb23a3 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Tue, 24 Aug 2021 14:21:13 -0600 Subject: [PATCH 08/10] Oops, packaging is a test dependency --- setup.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4ec8d059..f70c3a0d 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,11 @@ def readme(): return f.read() -extras = dict(geography=["GeoAlchemy2", "shapely"], alembic=["alembic"], tests=["pytz"]) +extras = dict( + geography=["GeoAlchemy2", "shapely"], + alembic=["alembic"], + tests=["packaging", "pytz"], +) extras["all"] = set(itertools.chain.from_iterable(extras.values())) setup( @@ -80,13 +84,12 @@ def readme(): # https://github.com/googleapis/google-cloud-python/issues/10566 "google-auth>=1.25.0,<3.0.0dev", # Work around pip wack. "google-cloud-bigquery>=2.24.1", - "packaging", "sqlalchemy>=1.2.0,<1.5.0dev", "future", ], extras_require=extras, python_requires=">=3.6, <3.10", - tests_require=["pytz"], + tests_require=["packaging", "pytz"], entry_points={ "sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"] }, From 61f37bfc4dcc6e341bbdc1242efc5e4d676e3f33 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Wed, 25 Aug 2021 11:04:56 -0600 Subject: [PATCH 09/10] fix: the unnest function lost needed type information --- sqlalchemy_bigquery/base.py | 17 ++++++++++ tests/unit/test_sqlalchemy_bigquery.py | 47 ++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 98edfb9e..3f81046f 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -35,6 +35,8 @@ from google.api_core.exceptions import NotFound import sqlalchemy +import sqlalchemy.sql.expression +import sqlalchemy.sql.functions import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError @@ -1058,6 +1060,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): return view.view_query +class unnest(sqlalchemy.sql.functions.GenericFunction): + def __init__(self, *args, **kwargs): + expr = kwargs.pop("expr", None) + if expr is not None: + args = (expr,) + args + if len(args) != 1: + raise TypeError("The unnest function requires a single argument.") + arg = args[0] + if isinstance(arg, sqlalchemy.sql.expression.ColumnElement): + if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY): + raise TypeError("The argument to unnest must have an ARRAY type.") + self.type = arg.type.item_type + super().__init__(*args, **kwargs) + + dialect = BigQueryDialect try: diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index a4c81367..bceb54fb 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -10,6 +10,7 @@ from google.cloud import bigquery from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.table import TableListItem +import packaging.version import pytest import sqlalchemy @@ -178,3 +179,49 @@ def test_follow_dialect_attribute_convention(): assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect + + +@pytest.mark.parametrize( + "args,kw,error", + [ + ((), {}, "The unnest function requires a single argument."), + ((1, 1), {}, "The unnest function requires a single argument."), + ((1,), {"expr": 1}, "The unnest function requires a single argument."), + ((1, 1), {"expr": 1}, "The unnest function requires a single argument."), + ( + (), + {"expr": sqlalchemy.Column("x", sqlalchemy.String)}, + "The argument to unnest must have an ARRAY type.", + ), + ( + (sqlalchemy.Column("x", sqlalchemy.String),), + {}, + "The argument to unnest must have an ARRAY type.", + ), + ], +) +def test_unnest_function_errors(args, kw, error): + import sqlalchemy_bigquery # noqa + + with pytest.raises(TypeError, match=error): + sqlalchemy.func.unnest(*args, **kw) + + +@pytest.mark.parametrize( + "args,kw", + [ + ((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}), + ((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}), + ], +) +def test_unnest_function(args, kw): + import sqlalchemy_bigquery # noqa + + f = sqlalchemy.func.unnest(*args, **kw) + assert isinstance(f.type, sqlalchemy.String) + if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse( + "1.4" + ): + assert isinstance( + sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String + ) From 65309131a92250ec69634d5cc56a2b95d49e391b Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Thu, 26 Aug 2021 13:28:42 -0600 Subject: [PATCH 10/10] explain the seemingly unused imports :) --- tests/unit/test_sqlalchemy_bigquery.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index bceb54fb..75cbec42 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -201,6 +201,8 @@ def test_follow_dialect_attribute_convention(): ], ) def test_unnest_function_errors(args, kw, error): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. import sqlalchemy_bigquery # noqa with pytest.raises(TypeError, match=error): @@ -215,6 +217,8 @@ def test_unnest_function_errors(args, kw, error): ], ) def test_unnest_function(args, kw): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. import sqlalchemy_bigquery # noqa f = sqlalchemy.func.unnest(*args, **kw)