From 5cfc28089baa6106cf30f9efb268231792da9251 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Wed, 31 Jan 2024 15:43:53 -0800 Subject: [PATCH 1/8] feat: grouping sets, rollup and cube compatibility --- sqlalchemy_bigquery/base.py | 34 ++++++++++++++++++-- tests/unit/test_compiler.py | 64 +++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index f4266f13..4548170a 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -38,6 +38,7 @@ import sqlalchemy import sqlalchemy.sql.expression import sqlalchemy.sql.functions +from sqlalchemy.sql.functions import rollup, cube, grouping_sets import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError, NoSuchColumnError @@ -340,9 +341,36 @@ def visit_label(self, *args, within_group_by=False, **kwargs): return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kw): - return super(BigQueryCompiler, self).group_by_clause( - select, **kw, within_group_by=True - ) + grouping_sets_exprs = [] + rollup_exprs = [] + cube_exprs = [] + + # Traverse select statement to extract grouping sets, rollup, and cube expressions + for expr in select._group_by_clause: + if isinstance(expr, grouping_sets): + grouping_sets_exprs.append( + self.process(expr.clauses) + ) # Assuming SQLAlchemy syntax + elif isinstance(expr, rollup): # Assuming SQLAlchemy syntax + rollup_exprs.append(self.process(expr.clauses)) + elif isinstance(expr, cube): # Assuming SQLAlchemy syntax + cube_exprs.append(self.process(expr.clauses)) + else: + # Handle regular group by expressions + pass + + clause = super(BigQueryCompiler, self).group_by_clause(select, **kw) + + if grouping_sets_exprs: + clause = ( + f"GROUP BY {clause} GROUPING SETS ({', '.join(grouping_sets_exprs)})" + ) + if rollup_exprs: + clause = f"GROUP BY {clause} ROLLUP ({', '.join(rollup_exprs)})" + if cube_exprs: + clause = f"GROUP BY {clause} CUBE ({', '.join(cube_exprs)})" + + return clause ############################################################################ # Handle parameters in in diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 139b6cbc..dc5d4438 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -22,6 +22,7 @@ from .conftest import setup_table from .conftest import sqlalchemy_1_4_or_higher, sqlalchemy_before_1_4 +from sqlalchemy.sql.functions import rollup, cube, grouping_sets def test_constraints_are_ignored(faux_conn, metadata): @@ -278,3 +279,66 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata) ) found_outer_sql = q.compile(faux_conn).string assert found_outer_sql == expected_outer_sql + + +def test_grouping_sets(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.Integer), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_sets(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY GROUPING SETS ((`table1`.`foo`), (`table1`.`bar`))" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql + + +def test_rollup(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.Integer), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + rollup(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql + + +def test_cube(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.Integer), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + cube(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql From e82f5ddcd20bd0cb40cabee28a2aaeea220a5396 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 1 Feb 2024 10:32:14 -0800 Subject: [PATCH 2/8] test commit to run kokooro tests --- testing/constraints-3.8.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index 667a747d..2aa0aa7f 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -3,7 +3,7 @@ # List *all* library dependencies and extras in this file. # Pin the version to the lower bound. # -# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", +# e.g., if setup.py has "foo >= 1.14.1, < 2.0.0dev", sqlalchemy==1.4.16 google-auth==1.25.0 google-cloud-bigquery==3.3.6 From ece7f1fb77c64c658af311c8e46f1daaaee0b6fb Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 1 Feb 2024 10:58:53 -0800 Subject: [PATCH 3/8] removed unnecessary clause function changes, edited tests --- sqlalchemy_bigquery/base.py | 33 ++------------------------------- testing/constraints-3.8.txt | 2 +- tests/unit/test_compiler.py | 2 +- 3 files changed, 4 insertions(+), 33 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 0ba602f9..3c191cce 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -342,37 +342,8 @@ def visit_label(self, *args, within_group_by=False, **kwargs): kwargs["render_label_as_label"] = args[0] return super(BigQueryCompiler, self).visit_label(*args, **kwargs) - def group_by_clause(self, select, **kw): - grouping_sets_exprs = [] - rollup_exprs = [] - cube_exprs = [] - - # Traverse select statement to extract grouping sets, rollup, and cube expressions - for expr in select._group_by_clause: - if isinstance(expr, grouping_sets): - grouping_sets_exprs.append( - self.process(expr.clauses) - ) # Assuming SQLAlchemy syntax - elif isinstance(expr, rollup): # Assuming SQLAlchemy syntax - rollup_exprs.append(self.process(expr.clauses)) - elif isinstance(expr, cube): # Assuming SQLAlchemy syntax - cube_exprs.append(self.process(expr.clauses)) - else: - # Handle regular group by expressions - pass - - clause = super(BigQueryCompiler, self).group_by_clause(select, **kw) - - if grouping_sets_exprs: - clause = ( - f"GROUP BY {clause} GROUPING SETS ({', '.join(grouping_sets_exprs)})" - ) - if rollup_exprs: - clause = f"GROUP BY {clause} ROLLUP ({', '.join(rollup_exprs)})" - if cube_exprs: - clause = f"GROUP BY {clause} CUBE ({', '.join(cube_exprs)})" - - return clause + def group_by_clause(self, select, **kwargs): + return super(BigQueryCompiler, self).group_by_clause(select, **kwargs) ############################################################################ # Handle parameters in in diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index 2aa0aa7f..667a747d 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -3,7 +3,7 @@ # List *all* library dependencies and extras in this file. # Pin the version to the lower bound. # -# e.g., if setup.py has "foo >= 1.14.1, < 2.0.0dev", +# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", sqlalchemy==1.4.16 google-auth==1.25.0 google-cloud-bigquery==3.3.6 diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 4857308e..1b22b31f 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -297,7 +297,7 @@ def test_grouping_sets(faux_conn, metadata): expected_sql = ( "SELECT `table1`.`foo`, `table1`.`bar` \n" - "FROM `table1` GROUP BY GROUPING SETS ((`table1`.`foo`), (`table1`.`bar`))" + "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)" ) found_sql = q.compile(faux_conn).string assert found_sql == expected_sql From 68afc3959d988b77a244699e7d5e4f39bd233917 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 1 Feb 2024 11:06:06 -0800 Subject: [PATCH 4/8] test basic implementation of group_by_clause and visit_label --- sqlalchemy_bigquery/base.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 3c191cce..61808998 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -38,7 +38,6 @@ import sqlalchemy import sqlalchemy.sql.expression import sqlalchemy.sql.functions -from sqlalchemy.sql.functions import rollup, cube, grouping_sets import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError, NoSuchColumnError @@ -333,13 +332,7 @@ def visit_column( return self.preparer.quote(tablename) + "." + name - def visit_label(self, *args, within_group_by=False, **kwargs): - # Use labels in GROUP BY clause. - # - # Flag set in the group_by_clause method. Works around missing - # equivalent to supports_simple_order_by_label for group by. - if within_group_by: - kwargs["render_label_as_label"] = args[0] + def visit_label(self, *args, **kwargs): return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kwargs): From bd38a5e14145ca77ec462e7cf4e9a989f2eb1fb3 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 5 Feb 2024 23:44:59 -0800 Subject: [PATCH 5/8] fixed render label as label assignment --- sqlalchemy_bigquery/base.py | 20 +++++++++++++++++--- tests/unit/test_compiler.py | 6 +++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 61808998..4d232bc0 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -332,11 +332,25 @@ def visit_column( return self.preparer.quote(tablename) + "." + name - def visit_label(self, *args, **kwargs): + def visit_label(self, *args, within_group_by=False, **kwargs): + # Use labels in GROUP BY clause. + # + # Flag set in the group_by_clause method. Works around missing + # equivalent to supports_simple_order_by_label for group by. + if within_group_by: + if all( + keyword not in str(args[0]) + for keyword in ("GROUPING SETS", "ROLLUP", "CUBE") + ): + kwargs["render_label_as_label"] = args[0] return super(BigQueryCompiler, self).visit_label(*args, **kwargs) - def group_by_clause(self, select, **kwargs): - return super(BigQueryCompiler, self).group_by_clause(select, **kwargs) + def group_by_clause(self, select, **kw): + return super(BigQueryCompiler, self).group_by_clause( + select, + **kw, + within_group_by=True, + ) ############################################################################ # Handle parameters in in diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 1b22b31f..55157537 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -288,7 +288,7 @@ def test_grouping_sets(faux_conn, metadata): "table1", metadata, sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), ) q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( @@ -309,7 +309,7 @@ def test_rollup(faux_conn, metadata): "table1", metadata, sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), ) q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( @@ -330,7 +330,7 @@ def test_cube(faux_conn, metadata): "table1", metadata, sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), ) q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( From 515481458952deb31a280c3b3987c67e694062fb Mon Sep 17 00:00:00 2001 From: kiraksi Date: Tue, 6 Feb 2024 13:51:51 -0800 Subject: [PATCH 6/8] added test case --- sqlalchemy_bigquery/base.py | 5 ++--- tests/unit/test_compiler.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 4d232bc0..4987b914 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -343,13 +343,12 @@ def visit_label(self, *args, within_group_by=False, **kwargs): for keyword in ("GROUPING SETS", "ROLLUP", "CUBE") ): kwargs["render_label_as_label"] = args[0] + return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kw): return super(BigQueryCompiler, self).group_by_clause( - select, - **kw, - within_group_by=True, + select, **kw, within_group_by=True ) ############################################################################ diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 55157537..903e7195 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -26,6 +26,7 @@ sqlalchemy_before_2_0, ) from sqlalchemy.sql.functions import rollup, cube, grouping_sets +from sqlalchemy import func def test_constraints_are_ignored(faux_conn, metadata): @@ -343,3 +344,24 @@ def test_cube(faux_conn, metadata): ) found_sql = q.compile(faux_conn).string assert found_sql == expected_sql + + +def test_multiple_grouping_sets(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql From 033d3294d52e03ee2c9f72a8dac5ac2bd6da8e4e Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 8 Feb 2024 09:47:44 -0800 Subject: [PATCH 7/8] reformat logic --- sqlalchemy_bigquery/base.py | 12 +++++++----- tests/unit/test_compiler.py | 1 - 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 4987b914..765ddb67 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -338,11 +338,13 @@ def visit_label(self, *args, within_group_by=False, **kwargs): # Flag set in the group_by_clause method. Works around missing # equivalent to supports_simple_order_by_label for group by. if within_group_by: - if all( - keyword not in str(args[0]) - for keyword in ("GROUPING SETS", "ROLLUP", "CUBE") - ): - kwargs["render_label_as_label"] = args[0] + column_label = args[0] + sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"} + for keyword in sql_keywords: + if keyword in str(column_label): + break + else: # for/else always happens unless break gets called + kwargs["render_label_as_label"] = column_label return super(BigQueryCompiler, self).visit_label(*args, **kwargs) diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 903e7195..def13cfd 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -26,7 +26,6 @@ sqlalchemy_before_2_0, ) from sqlalchemy.sql.functions import rollup, cube, grouping_sets -from sqlalchemy import func def test_constraints_are_ignored(faux_conn, metadata): From 0c882b96eb1d54c3e4211ac53e1fba2027fb4435 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 8 Feb 2024 09:53:56 -0800 Subject: [PATCH 8/8] test commit --- sqlalchemy_bigquery/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 765ddb67..e80f2891 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -402,8 +402,6 @@ def visit_not_in_op_binary(self, binary, operator, **kw): + ")" ) - visit_notin_op_binary = visit_not_in_op_binary # before 1.4 - ############################################################################ ############################################################################