Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e8d44ed

Browse files
haseebmalik18Ankurdeewan
authored andcommitted
Add direct GCS export to DatabricksSqlOperator with Parquet/Avro support apache#55128 (apache#60543)
1 parent b073b02 commit e8d44ed

5 files changed

Lines changed: 315 additions & 33 deletions

File tree

dev/breeze/tests/test_selective_checks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,7 +1948,7 @@ def test_expected_output_push(
19481948
),
19491949
{
19501950
"selected-providers-list-as-string": "amazon apache.beam apache.cassandra apache.kafka "
1951-
"cncf.kubernetes common.compat common.sql "
1951+
"cncf.kubernetes common.compat common.sql databricks "
19521952
"facebook google hashicorp http microsoft.azure microsoft.mssql mysql "
19531953
"openlineage oracle postgres presto salesforce samba sftp ssh standard trino",
19541954
"all-python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -1960,7 +1960,7 @@ def test_expected_output_push(
19601960
"skip-providers-tests": "false",
19611961
"docs-build": "true",
19621962
"docs-list-as-string": "apache-airflow helm-chart amazon apache.beam apache.cassandra "
1963-
"apache.kafka cncf.kubernetes common.compat common.sql facebook google hashicorp http microsoft.azure "
1963+
"apache.kafka cncf.kubernetes common.compat common.sql databricks facebook google hashicorp http microsoft.azure "
19641964
"microsoft.mssql mysql openlineage oracle postgres "
19651965
"presto salesforce samba sftp ssh standard trino",
19661966
"skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
@@ -1974,7 +1974,7 @@ def test_expected_output_push(
19741974
{
19751975
"description": "amazon...standard",
19761976
"test_types": "Providers[amazon] Providers[apache.beam,apache.cassandra,"
1977-
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
1977+
"apache.kafka,cncf.kubernetes,common.compat,common.sql,databricks,facebook,"
19781978
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
19791979
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
19801980
"Providers[google] "
@@ -2245,7 +2245,7 @@ def test_upgrade_to_newer_dependencies(
22452245
("providers/google/docs/some_file.rst",),
22462246
{
22472247
"docs-list-as-string": "amazon apache.beam apache.cassandra apache.kafka "
2248-
"cncf.kubernetes common.compat common.sql facebook google hashicorp http "
2248+
"cncf.kubernetes common.compat common.sql databricks facebook google hashicorp http "
22492249
"microsoft.azure microsoft.mssql mysql openlineage oracle "
22502250
"postgres presto salesforce samba sftp ssh standard trino",
22512251
},

providers/databricks/docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ Dependent package
132132
================================================================================================================== =================
133133
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
134134
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
135+
`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_ ``google``
135136
`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
136137
================================================================================================================== =================
137138

providers/databricks/pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ dependencies = [
9393
"sqlalchemy" = [
9494
"databricks-sqlalchemy>=1.0.2",
9595
]
96+
"google" = [
97+
"apache-airflow-providers-google>=10.24.0"
98+
]
99+
"avro" = [
100+
"fastavro>=1.9.0"
101+
]
96102

97103
[dependency-groups]
98104
dev = [
@@ -101,6 +107,7 @@ dev = [
101107
"apache-airflow-devel-common",
102108
"apache-airflow-providers-common-compat",
103109
"apache-airflow-providers-common-sql",
110+
"apache-airflow-providers-google",
104111
"apache-airflow-providers-openlineage",
105112
# Additional devel dependencies (do not remove this line and add extra development dependencies)
106113
# Need to exclude 1.3.0 due to missing aarch64 binaries, fixed with 1.3.1++

providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py

Lines changed: 164 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,20 @@
2121

2222
import csv
2323
import json
24+
import os
2425
from collections.abc import Sequence
2526
from functools import cached_property
27+
from tempfile import NamedTemporaryFile
2628
from typing import TYPE_CHECKING, Any, ClassVar
29+
from urllib.parse import urlparse
2730

2831
from databricks.sql.utils import ParamEscaper
2932

30-
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
33+
from airflow.providers.common.compat.sdk import (
34+
AirflowException,
35+
AirflowOptionalProviderFeatureException,
36+
BaseOperator,
37+
)
3138
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
3239
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
3340

@@ -62,13 +69,27 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
6269
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
6370
:param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
6471
:param output_path: optional string specifying the file to which write selected data. (templated)
65-
:param output_format: format of output data if ``output_path` is specified.
66-
Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
72+
Supports local file paths and GCS URIs (e.g., ``gs://bucket/path/file.parquet``).
73+
When using GCS URIs, requires the ``apache-airflow-providers-google`` package.
74+
:param output_format: format of output data if ``output_path`` is specified.
75+
Possible values are ``csv``, ``json``, ``jsonl``, ``parquet``, ``avro``. Default is ``csv``.
6776
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
77+
:param gcp_conn_id: The connection ID to use for connecting to Google Cloud when using GCS output path.
78+
Default is ``google_cloud_default``.
79+
:param gcs_impersonation_chain: Optional service account to impersonate using short-term
80+
credentials for GCS upload, or chained list of accounts required to get the access_token
81+
of the last account in the list, which will be impersonated in the request. (templated)
6882
"""
6983

7084
template_fields: Sequence[str] = tuple(
71-
{"_output_path", "schema", "catalog", "http_headers", "databricks_conn_id"}
85+
{
86+
"_output_path",
87+
"schema",
88+
"catalog",
89+
"http_headers",
90+
"databricks_conn_id",
91+
"_gcs_impersonation_chain",
92+
}
7293
| set(SQLExecuteQueryOperator.template_fields)
7394
)
7495

@@ -90,6 +111,8 @@ def __init__(
90111
output_format: str = "csv",
91112
csv_params: dict[str, Any] | None = None,
92113
client_parameters: dict[str, Any] | None = None,
114+
gcp_conn_id: str = "google_cloud_default",
115+
gcs_impersonation_chain: str | Sequence[str] | None = None,
93116
**kwargs,
94117
) -> None:
95118
super().__init__(conn_id=databricks_conn_id, **kwargs)
@@ -105,6 +128,8 @@ def __init__(
105128
self.http_headers = http_headers
106129
self.catalog = catalog
107130
self.schema = schema
131+
self._gcp_conn_id = gcp_conn_id
132+
self._gcs_impersonation_chain = gcs_impersonation_chain
108133

109134
@cached_property
110135
def _hook(self) -> DatabricksSqlHook:
@@ -127,41 +152,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
127152
def _should_run_output_processing(self) -> bool:
128153
return self.do_xcom_push or bool(self._output_path)
129154

155+
@property
156+
def _is_gcs_output(self) -> bool:
157+
"""Check if the output path is a GCS URI."""
158+
return self._output_path.startswith("gs://") if self._output_path else False
159+
160+
def _parse_gcs_path(self, path: str) -> tuple[str, str]:
161+
"""Parse a GCS URI into bucket and object name."""
162+
parsed = urlparse(path)
163+
bucket = parsed.netloc
164+
object_name = parsed.path.lstrip("/")
165+
return bucket, object_name
166+
167+
def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
168+
"""Upload a local file to GCS."""
169+
try:
170+
from airflow.providers.google.cloud.hooks.gcs import GCSHook
171+
except ImportError:
172+
raise AirflowOptionalProviderFeatureException(
173+
"The 'apache-airflow-providers-google' package is required for GCS output. "
174+
"Install it with: pip install apache-airflow-providers-google"
175+
)
176+
177+
bucket, object_name = self._parse_gcs_path(gcs_path)
178+
hook = GCSHook(
179+
gcp_conn_id=self._gcp_conn_id,
180+
impersonation_chain=self._gcs_impersonation_chain,
181+
)
182+
hook.upload(
183+
bucket_name=bucket,
184+
object_name=object_name,
185+
filename=local_path,
186+
)
187+
self.log.info("Uploaded output to %s", gcs_path)
188+
189+
def _write_parquet(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
190+
"""Write data to a Parquet file."""
191+
import pyarrow as pa
192+
import pyarrow.parquet as pq
193+
194+
data: dict[str, list] = {name: [] for name in field_names}
195+
for row in rows:
196+
row_dict = row._asdict()
197+
for name in field_names:
198+
data[name].append(row_dict[name])
199+
200+
table = pa.Table.from_pydict(data)
201+
pq.write_table(table, file_path)
202+
203+
def _write_avro(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
204+
"""Write data to an Avro file using fastavro."""
205+
try:
206+
from fastavro import writer
207+
except ImportError:
208+
raise AirflowOptionalProviderFeatureException(
209+
"The 'fastavro' package is required for Avro output. Install it with: pip install fastavro"
210+
)
211+
212+
data: dict[str, list] = {name: [] for name in field_names}
213+
for row in rows:
214+
row_dict = row._asdict()
215+
for name in field_names:
216+
data[name].append(row_dict[name])
217+
218+
schema_fields = []
219+
for name in field_names:
220+
sample_val = next(
221+
(data[name][i] for i in range(len(data[name])) if data[name][i] is not None), None
222+
)
223+
if sample_val is None:
224+
avro_type = ["null", "string"]
225+
elif isinstance(sample_val, bool):
226+
avro_type = ["null", "boolean"]
227+
elif isinstance(sample_val, int):
228+
avro_type = ["null", "long"]
229+
elif isinstance(sample_val, float):
230+
avro_type = ["null", "double"]
231+
else:
232+
avro_type = ["null", "string"]
233+
schema_fields.append({"name": name, "type": avro_type})
234+
235+
avro_schema = {
236+
"type": "record",
237+
"name": "QueryResult",
238+
"fields": schema_fields,
239+
}
240+
241+
records = [row._asdict() for row in rows]
242+
with open(file_path, "wb") as f:
243+
writer(f, avro_schema, records)
244+
130245
def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
131246
if not self._output_path:
132247
return list(zip(descriptions, results))
133248
if not self._output_format:
134249
raise AirflowException("Output format should be specified!")
135-
# Output to a file only the result of last query
250+
136251
last_description = descriptions[-1]
137252
last_results = results[-1]
138253
if last_description is None:
139-
raise AirflowException("There is missing description present for the output file. .")
254+
raise AirflowException("There is missing description present for the output file.")
140255
field_names = [field[0] for field in last_description]
141-
if self._output_format.lower() == "csv":
142-
with open(self._output_path, "w", newline="") as file:
143-
if self._csv_params:
144-
csv_params = self._csv_params
145-
else:
146-
csv_params = {}
147-
write_header = csv_params.get("header", True)
148-
if "header" in csv_params:
149-
del csv_params["header"]
150-
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
151-
if write_header:
152-
writer.writeheader()
153-
for row in last_results:
154-
writer.writerow(row._asdict())
155-
elif self._output_format.lower() == "json":
156-
with open(self._output_path, "w") as file:
157-
file.write(json.dumps([row._asdict() for row in last_results]))
158-
elif self._output_format.lower() == "jsonl":
159-
with open(self._output_path, "w") as file:
160-
for row in last_results:
161-
file.write(json.dumps(row._asdict()))
162-
file.write("\n")
256+
257+
if self._is_gcs_output:
258+
suffix = f".{self._output_format.lower()}"
259+
tmp_file = NamedTemporaryFile(mode="w", suffix=suffix, delete=False, newline="")
260+
local_path = tmp_file.name
261+
tmp_file.close()
163262
else:
164-
raise AirflowException(f"Unsupported output format: '{self._output_format}'")
263+
local_path = self._output_path
264+
265+
try:
266+
output_format = self._output_format.lower()
267+
if output_format == "csv":
268+
with open(local_path, "w", newline="") as file:
269+
if self._csv_params:
270+
csv_params = self._csv_params.copy()
271+
else:
272+
csv_params = {}
273+
write_header = csv_params.pop("header", True)
274+
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
275+
if write_header:
276+
writer.writeheader()
277+
for row in last_results:
278+
writer.writerow(row._asdict())
279+
elif output_format == "json":
280+
with open(local_path, "w") as file:
281+
file.write(json.dumps([row._asdict() for row in last_results]))
282+
elif output_format == "jsonl":
283+
with open(local_path, "w") as file:
284+
for row in last_results:
285+
file.write(json.dumps(row._asdict()))
286+
file.write("\n")
287+
elif output_format == "parquet":
288+
self._write_parquet(local_path, field_names, last_results)
289+
elif output_format == "avro":
290+
self._write_avro(local_path, field_names, last_results)
291+
else:
292+
raise ValueError(f"Unsupported output format: '{self._output_format}'")
293+
294+
if self._is_gcs_output:
295+
self._upload_to_gcs(local_path, self._output_path)
296+
finally:
297+
if self._is_gcs_output and os.path.exists(local_path):
298+
os.unlink(local_path)
299+
165300
return list(zip(descriptions, results))
166301

167302

0 commit comments

Comments
 (0)