diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index bcff58be..e80f2891 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -338,7 +338,14 @@ def visit_label(self, *args, within_group_by=False, **kwargs): # Flag set in the group_by_clause method. Works around missing # equivalent to supports_simple_order_by_label for group by. if within_group_by: - kwargs["render_label_as_label"] = args[0] + column_label = args[0] + sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"} + for keyword in sql_keywords: + if keyword in str(column_label): + break + else: # for/else always happens unless break gets called + kwargs["render_label_as_label"] = column_label + return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kw): @@ -395,8 +402,6 @@ def visit_not_in_op_binary(self, binary, operator, **kw): + ")" ) - visit_notin_op_binary = visit_not_in_op_binary # before 1.4 - ############################################################################ ############################################################################ diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 5ac71485..def13cfd 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -25,6 +25,7 @@ sqlalchemy_2_0_or_higher, sqlalchemy_before_2_0, ) +from sqlalchemy.sql.functions import rollup, cube, grouping_sets def test_constraints_are_ignored(faux_conn, metadata): @@ -279,3 +280,87 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata) ) found_outer_sql = q.compile(faux_conn).string assert found_outer_sql == expected_outer_sql + + +def test_grouping_sets(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_sets(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql + + +def test_rollup(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + rollup(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql + + +def test_cube(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + cube(table.c.foo, table.c.bar) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql + + +def test_multiple_grouping_sets(faux_conn, metadata): + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( + grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo) + ) + + expected_sql = ( + "SELECT `table1`.`foo`, `table1`.`bar` \n" + "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)" + ) + found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql