fix: reduce bigquery table modification via DML for to_gbq #1737

Open · wants to merge 1 commit into base: main
57 changes: 57 additions & 0 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -17,6 +17,7 @@
import dataclasses
import typing

from google.cloud import bigquery
import pyarrow as pa
import sqlglot as sg
import sqlglot.dialects.bigquery
@@ -104,6 +105,24 @@ def from_pyarrow(
)
return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)

@classmethod
def from_query_string(
cls,
query_string: str,
) -> SQLGlotIR:
"""Builds SQLGlot expression from a query string"""
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
cte_name = sge.to_identifier(
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
)
cte = sge.CTE(
this=query_string,
alias=cte_name,
)
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
select_expr.set("with", sge.With(expressions=[cte]))
return cls(expr=select_expr, uid_gen=uid_gen)

def select(
self,
selected_cols: tuple[tuple[str, sge.Expression], ...],
@@ -133,6 +152,36 @@ def project(
select_expr = self.expr.select(*projected_cols_expr, append=True)
return SQLGlotIR(expr=select_expr)

def insert(
self,
destination: bigquery.TableReference,
) -> str:
return sge.insert(self.expr.subquery(), _table(destination)).sql(
dialect=self.dialect, pretty=self.pretty
)

def replace(
self,
destination: bigquery.TableReference,
) -> str:
# Workaround for SQLGlot breaking change:
# https://github.com/tobymao/sqlglot/pull/4495
whens_expr = [
sge.When(matched=False, source=True, then=sge.Delete()),
sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
]
whens_str = "\n".join(
when_expr.sql(dialect=self.dialect, pretty=self.pretty)
for when_expr in whens_expr
)

merge_str = sge.Merge(
this=_table(destination),
using=self.expr.subquery(),
on=_literal(False, dtypes.BOOL_DTYPE),
).sql(dialect=self.dialect, pretty=self.pretty)
return f"{merge_str}\n{whens_str}"

def _encapsulate_as_cte(
self,
) -> sge.Select:
@@ -190,3 +239,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:

def _cast(arg: typing.Any, to: str) -> sge.Cast:
return sge.Cast(this=arg, to=to)


def _table(table: bigquery.TableReference) -> sge.Table:
return sge.Table(
this=sg.to_identifier(table.table_id, quoted=True),
db=sg.to_identifier(table.dataset_id, quoted=True),
catalog=sg.to_identifier(table.project, quoted=True),
)
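
For reference, a minimal sketch of how the new helpers could be exercised; the destination table and query below are placeholders, and the exact formatting of the emitted SQL depends on the installed sqlglot version:

from google.cloud import bigquery

from bigframes.core.compile.sqlglot.sqlglot_ir import SQLGlotIR

# Placeholder destination; substitute a real project, dataset, and table.
dest = bigquery.TableReference.from_string("my-project.my_dataset.my_table")

ir = SQLGlotIR.from_query_string("SELECT 1 AS x")

# `insert` wraps the source query as a subquery of an INSERT statement,
# roughly: INSERT INTO `my-project`.`my_dataset`.`my_table` SELECT * FROM (...)
print(ir.insert(dest))

# `replace` emits a MERGE whose ON condition is always FALSE, so every existing
# target row is deleted (WHEN NOT MATCHED BY SOURCE THEN DELETE) and every
# source row is inserted, replacing the table contents in a single DML statement.
print(ir.replace(dest))
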
65 changes: 55 additions & 10 deletions bigframes/session/bq_caching_executor.py
@@ -29,9 +29,11 @@

import bigframes.core
from bigframes.core import compile, rewrite
import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
import bigframes.core.guid
import bigframes.core.nodes as nodes
import bigframes.core.ordering as order
import bigframes.core.schema as schemata
import bigframes.core.tree_properties as tree_properties
import bigframes.dtypes
import bigframes.exceptions as bfe
@@ -206,17 +208,45 @@ def export_gbq(
if bigframes.options.compute.enable_multi_query_execution:
self._simplify_with_caching(array_value)

dispositions = {
"fail": bigquery.WriteDisposition.WRITE_EMPTY,
"replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
"append": bigquery.WriteDisposition.WRITE_APPEND,
}
table_exists = True
try:
table = self.bqclient.get_table(destination)
if if_exists == "fail":
raise ValueError(f"Table already exists: {destination}")
except google.api_core.exceptions.NotFound:
table_exists = False

if len(cluster_cols) != 0:
if table_exists and table.clustering_fields != cluster_cols:
raise ValueError(
"Table clustering fields cannot be changed after the table has "
f"been created. Existing clustering fields: {table.clustering_fields}"
)

sql = self.to_sql(array_value, ordered=False)
job_config = bigquery.QueryJobConfig(
write_disposition=dispositions[if_exists],
destination=destination,
clustering_fields=cluster_cols if cluster_cols else None,
)
if table_exists and _if_schema_match(table.schema, array_value.schema):
# b/409086472: Uses DML for table appends and replacements to avoid
# BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
# https://cloud.google.com/bigquery/quotas#standard_tables
job_config = bigquery.QueryJobConfig()
ir = sqlglot_ir.SQLGlotIR.from_query_string(sql)
if if_exists == "append":
sql = ir.insert(destination)
else: # for "replace"
assert if_exists == "replace"
sql = ir.replace(destination)
else:
dispositions = {
"fail": bigquery.WriteDisposition.WRITE_EMPTY,
"replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
"append": bigquery.WriteDisposition.WRITE_APPEND,
}
job_config = bigquery.QueryJobConfig(
write_disposition=dispositions[if_exists],
destination=destination,
clustering_fields=cluster_cols if cluster_cols else None,
)

# TODO(swast): plumb through the api_name of the user-facing api that
# caused this query.
_, query_job = self._run_execute_query(
Expand Down Expand Up @@ -572,6 +602,21 @@ def _execute_plan(
)


def _if_schema_match(
table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
) -> bool:
if len(table_schema) != len(schema.items):
return False
for field in table_schema:
if field.name not in schema.names:
return False
if bigframes.dtypes.convert_schema_field(field)[1] != schema.get_type(
field.name
):
return False
return True


def _sanitize(
schema: Tuple[bigquery.SchemaField, ...]
) -> Tuple[bigquery.SchemaField, ...]:
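From the caller's side nothing changes in the `to_gbq` signature; only the statements issued against BigQuery differ. A rough usage sketch, assuming a configured GCP project (the table and column names below are placeholders):

import bigframes.pandas as bpd

df = bpd.DataFrame({"int64_col": [1, 2, 3], "string_col": ["a", "b", "c"]})

# The first write still creates the table through a query job with a destination.
df.to_gbq("my-project.my_dataset.demo_table")

# When the destination already exists and its schema matches the DataFrame,
# "append" is executed as an INSERT ... SELECT and "replace" as a MERGE,
# rather than via WRITE_APPEND / WRITE_TRUNCATE dispositions that rewrite
# table metadata and can hit the RATE_LIMIT_EXCEEDED quota on frequent writes.
df.to_gbq("my-project.my_dataset.demo_table", if_exists="append")
df.to_gbq("my-project.my_dataset.demo_table", if_exists="replace")
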
122 changes: 81 additions & 41 deletions tests/system/small/test_dataframe_io.py
@@ -458,7 +458,7 @@ def test_to_csv_tabs(
[True, False],
)
@pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
def test_to_gbq_index(scalars_dfs, dataset_id, index):
def test_to_gbq_w_index(scalars_dfs, dataset_id, index):
"""Test the `to_gbq` API with the `index` parameter."""
scalars_df, scalars_pandas_df = scalars_dfs
destination_table = f"{dataset_id}.test_index_df_to_gbq_{index}"
@@ -485,48 +485,67 @@ def test_to_gbq_index(scalars_dfs, dataset_id, index):
pd.testing.assert_frame_equal(df_out, expected, check_index_type=False)


@pytest.mark.parametrize(
("if_exists", "expected_index"),
[
pytest.param("replace", 1),
pytest.param("append", 2),
pytest.param(
"fail",
0,
marks=pytest.mark.xfail(
raises=google.api_core.exceptions.Conflict,
),
),
pytest.param(
"unknown",
0,
marks=pytest.mark.xfail(
raises=ValueError,
),
),
],
)
@pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
def test_to_gbq_if_exists(
scalars_df_default_index,
scalars_pandas_df_default_index,
dataset_id,
if_exists,
expected_index,
):
"""Test the `to_gbq` API with the `if_exists` parameter."""
destination_table = f"{dataset_id}.test_to_gbq_if_exists_{if_exists}"
def test_to_gbq_if_exists_is_fail(scalars_dfs, dataset_id):
scalars_df, scalars_pandas_df = scalars_dfs
destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_fail"
scalars_df.to_gbq(destination_table)

scalars_df_default_index.to_gbq(destination_table)
scalars_df_default_index.to_gbq(destination_table, if_exists=if_exists)
gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == len(scalars_pandas_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)

gcs_df = pd.read_gbq(destination_table)
assert len(gcs_df.index) == expected_index * len(
scalars_pandas_df_default_index.index
)
pd.testing.assert_index_equal(
gcs_df.columns, scalars_pandas_df_default_index.columns
)
# Test that the default value is "fail"
with pytest.raises(ValueError, match="Table already exists"):
scalars_df.to_gbq(destination_table)

with pytest.raises(ValueError, match="Table already exists"):
scalars_df.to_gbq(destination_table, if_exists="fail")


def test_to_gbq_if_exists_is_replace(scalars_dfs, dataset_id):
scalars_df, scalars_pandas_df = scalars_dfs
destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_replace"
scalars_df.to_gbq(destination_table)

gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == len(scalars_pandas_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)

# When replacing a table with the same schema
scalars_df.to_gbq(destination_table, if_exists="replace")
gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == len(scalars_pandas_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)

# When replacing a table with a different schema
partial_scalars_df = scalars_df.drop(columns=["string_col"])
partial_scalars_df.to_gbq(destination_table, if_exists="replace")
gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == len(partial_scalars_df)
pd.testing.assert_index_equal(gcs_df.columns, partial_scalars_df.columns)


def test_to_gbq_if_exists_is_append(scalars_dfs, dataset_id):
scalars_df, scalars_pandas_df = scalars_dfs
destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_append"
scalars_df.to_gbq(destination_table)

gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == len(scalars_pandas_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)

# When appending to a table with the same schema
scalars_df.to_gbq(destination_table, if_exists="append")
gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == 2 * len(scalars_pandas_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)

# When appending to a table with a different schema
partial_scalars_df = scalars_df.drop(columns=["string_col"])
partial_scalars_df.to_gbq(destination_table, if_exists="append")
gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
assert len(gcs_df) == 3 * len(partial_scalars_df)
pd.testing.assert_index_equal(gcs_df.columns, scalars_df.columns)


def test_to_gbq_w_duplicate_column_names(
Expand Down Expand Up @@ -773,6 +792,27 @@ def test_to_gbq_w_clustering_no_destination(
assert table.expires is not None


def test_to_gbq_w_clustering_existing_table(
scalars_df_default_index,
dataset_id,
bigquery_client,
):
destination_table = f"{dataset_id}.test_to_gbq_w_clustering_existing_table"
scalars_df_default_index.to_gbq(destination_table)

table = bigquery_client.get_table(destination_table)
assert table.clustering_fields is None
assert table.expires is None

with pytest.raises(ValueError, match="Table clustering fields cannot be changed"):
clustering_columns = ["int64_col"]
scalars_df_default_index.to_gbq(
destination_table,
if_exists="replace",
clustering_columns=clustering_columns,
)


def test_to_gbq_w_invalid_destination_table(scalars_df_index):
with pytest.raises(ValueError):
scalars_df_index.to_gbq("table_id")
5 changes: 5 additions & 0 deletions tests/unit/test_dataframe_io.py
@@ -49,3 +49,8 @@ def test_dataframe_to_pandas(mock_df, api_name, kwargs):
mock_df.to_pandas.assert_called_once_with(
allow_large_results=kwargs["allow_large_results"]
)


def test_to_gbq_if_exists_invalid(mock_df):
with pytest.raises(ValueError, match="Got invalid value 'invalid' for if_exists."):
mock_df.to_gbq("a.b.c", if_exists="invalid")