2121
2222import csv
2323import json
24+ import os
2425from collections .abc import Sequence
2526from functools import cached_property
27+ from tempfile import NamedTemporaryFile
2628from typing import TYPE_CHECKING , Any , ClassVar
29+ from urllib .parse import urlparse
2730
2831from databricks .sql .utils import ParamEscaper
2932
30- from airflow .providers .common .compat .sdk import AirflowException , BaseOperator
33+ from airflow .providers .common .compat .sdk import (
34+ AirflowException ,
35+ AirflowOptionalProviderFeatureException ,
36+ BaseOperator ,
37+ )
3138from airflow .providers .common .sql .operators .sql import SQLExecuteQueryOperator
3239from airflow .providers .databricks .hooks .databricks_sql import DatabricksSqlHook
3340
@@ -62,13 +69,27 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
6269 :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
6370 :param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
6471 :param output_path: optional string specifying the file to which write selected data. (templated)
65- :param output_format: format of output data if ``output_path` is specified.
66- Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
72+ Supports local file paths and GCS URIs (e.g., ``gs://bucket/path/file.parquet``).
73+ When using GCS URIs, requires the ``apache-airflow-providers-google`` package.
74+ :param output_format: format of output data if ``output_path`` is specified.
75+ Possible values are ``csv``, ``json``, ``jsonl``, ``parquet``, ``avro``. Default is ``csv``.
6776 :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
77+ :param gcp_conn_id: The connection ID to use for connecting to Google Cloud when using GCS output path.
78+ Default is ``google_cloud_default``.
79+ :param gcs_impersonation_chain: Optional service account to impersonate using short-term
80+ credentials for GCS upload, or chained list of accounts required to get the access_token
81+ of the last account in the list, which will be impersonated in the request. (templated)
6882 """
6983
7084 template_fields : Sequence [str ] = tuple (
71- {"_output_path" , "schema" , "catalog" , "http_headers" , "databricks_conn_id" }
85+ {
86+ "_output_path" ,
87+ "schema" ,
88+ "catalog" ,
89+ "http_headers" ,
90+ "databricks_conn_id" ,
91+ "_gcs_impersonation_chain" ,
92+ }
7293 | set (SQLExecuteQueryOperator .template_fields )
7394 )
7495
@@ -90,6 +111,8 @@ def __init__(
90111 output_format : str = "csv" ,
91112 csv_params : dict [str , Any ] | None = None ,
92113 client_parameters : dict [str , Any ] | None = None ,
114+ gcp_conn_id : str = "google_cloud_default" ,
115+ gcs_impersonation_chain : str | Sequence [str ] | None = None ,
93116 ** kwargs ,
94117 ) -> None :
95118 super ().__init__ (conn_id = databricks_conn_id , ** kwargs )
@@ -105,6 +128,8 @@ def __init__(
105128 self .http_headers = http_headers
106129 self .catalog = catalog
107130 self .schema = schema
131+ self ._gcp_conn_id = gcp_conn_id
132+ self ._gcs_impersonation_chain = gcs_impersonation_chain
108133
109134 @cached_property
110135 def _hook (self ) -> DatabricksSqlHook :
@@ -127,41 +152,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
127152 def _should_run_output_processing (self ) -> bool :
128153 return self .do_xcom_push or bool (self ._output_path )
129154
155+ @property
156+ def _is_gcs_output (self ) -> bool :
157+ """Check if the output path is a GCS URI."""
158+ return self ._output_path .startswith ("gs://" ) if self ._output_path else False
159+
160+ def _parse_gcs_path (self , path : str ) -> tuple [str , str ]:
161+ """Parse a GCS URI into bucket and object name."""
162+ parsed = urlparse (path )
163+ bucket = parsed .netloc
164+ object_name = parsed .path .lstrip ("/" )
165+ return bucket , object_name
166+
167+ def _upload_to_gcs (self , local_path : str , gcs_path : str ) -> None :
168+ """Upload a local file to GCS."""
169+ try :
170+ from airflow .providers .google .cloud .hooks .gcs import GCSHook
171+ except ImportError :
172+ raise AirflowOptionalProviderFeatureException (
173+ "The 'apache-airflow-providers-google' package is required for GCS output. "
174+ "Install it with: pip install apache-airflow-providers-google"
175+ )
176+
177+ bucket , object_name = self ._parse_gcs_path (gcs_path )
178+ hook = GCSHook (
179+ gcp_conn_id = self ._gcp_conn_id ,
180+ impersonation_chain = self ._gcs_impersonation_chain ,
181+ )
182+ hook .upload (
183+ bucket_name = bucket ,
184+ object_name = object_name ,
185+ filename = local_path ,
186+ )
187+ self .log .info ("Uploaded output to %s" , gcs_path )
188+
189+ def _write_parquet (self , file_path : str , field_names : list [str ], rows : list [Any ]) -> None :
190+ """Write data to a Parquet file."""
191+ import pyarrow as pa
192+ import pyarrow .parquet as pq
193+
194+ data : dict [str , list ] = {name : [] for name in field_names }
195+ for row in rows :
196+ row_dict = row ._asdict ()
197+ for name in field_names :
198+ data [name ].append (row_dict [name ])
199+
200+ table = pa .Table .from_pydict (data )
201+ pq .write_table (table , file_path )
202+
203+ def _write_avro (self , file_path : str , field_names : list [str ], rows : list [Any ]) -> None :
204+ """Write data to an Avro file using fastavro."""
205+ try :
206+ from fastavro import writer
207+ except ImportError :
208+ raise AirflowOptionalProviderFeatureException (
209+ "The 'fastavro' package is required for Avro output. Install it with: pip install fastavro"
210+ )
211+
212+ data : dict [str , list ] = {name : [] for name in field_names }
213+ for row in rows :
214+ row_dict = row ._asdict ()
215+ for name in field_names :
216+ data [name ].append (row_dict [name ])
217+
218+ schema_fields = []
219+ for name in field_names :
220+ sample_val = next (
221+ (data [name ][i ] for i in range (len (data [name ])) if data [name ][i ] is not None ), None
222+ )
223+ if sample_val is None :
224+ avro_type = ["null" , "string" ]
225+ elif isinstance (sample_val , bool ):
226+ avro_type = ["null" , "boolean" ]
227+ elif isinstance (sample_val , int ):
228+ avro_type = ["null" , "long" ]
229+ elif isinstance (sample_val , float ):
230+ avro_type = ["null" , "double" ]
231+ else :
232+ avro_type = ["null" , "string" ]
233+ schema_fields .append ({"name" : name , "type" : avro_type })
234+
235+ avro_schema = {
236+ "type" : "record" ,
237+ "name" : "QueryResult" ,
238+ "fields" : schema_fields ,
239+ }
240+
241+ records = [row ._asdict () for row in rows ]
242+ with open (file_path , "wb" ) as f :
243+ writer (f , avro_schema , records )
244+
130245 def _process_output (self , results : list [Any ], descriptions : list [Sequence [Sequence ] | None ]) -> list [Any ]:
131246 if not self ._output_path :
132247 return list (zip (descriptions , results ))
133248 if not self ._output_format :
134249 raise AirflowException ("Output format should be specified!" )
135- # Output to a file only the result of last query
250+
136251 last_description = descriptions [- 1 ]
137252 last_results = results [- 1 ]
138253 if last_description is None :
139- raise AirflowException ("There is missing description present for the output file. . " )
254+ raise AirflowException ("There is missing description present for the output file." )
140255 field_names = [field [0 ] for field in last_description ]
141- if self ._output_format .lower () == "csv" :
142- with open (self ._output_path , "w" , newline = "" ) as file :
143- if self ._csv_params :
144- csv_params = self ._csv_params
145- else :
146- csv_params = {}
147- write_header = csv_params .get ("header" , True )
148- if "header" in csv_params :
149- del csv_params ["header" ]
150- writer = csv .DictWriter (file , fieldnames = field_names , ** csv_params )
151- if write_header :
152- writer .writeheader ()
153- for row in last_results :
154- writer .writerow (row ._asdict ())
155- elif self ._output_format .lower () == "json" :
156- with open (self ._output_path , "w" ) as file :
157- file .write (json .dumps ([row ._asdict () for row in last_results ]))
158- elif self ._output_format .lower () == "jsonl" :
159- with open (self ._output_path , "w" ) as file :
160- for row in last_results :
161- file .write (json .dumps (row ._asdict ()))
162- file .write ("\n " )
256+
257+ if self ._is_gcs_output :
258+ suffix = f".{ self ._output_format .lower ()} "
259+ tmp_file = NamedTemporaryFile (mode = "w" , suffix = suffix , delete = False , newline = "" )
260+ local_path = tmp_file .name
261+ tmp_file .close ()
163262 else :
164- raise AirflowException (f"Unsupported output format: '{ self ._output_format } '" )
263+ local_path = self ._output_path
264+
265+ try :
266+ output_format = self ._output_format .lower ()
267+ if output_format == "csv" :
268+ with open (local_path , "w" , newline = "" ) as file :
269+ if self ._csv_params :
270+ csv_params = self ._csv_params .copy ()
271+ else :
272+ csv_params = {}
273+ write_header = csv_params .pop ("header" , True )
274+ writer = csv .DictWriter (file , fieldnames = field_names , ** csv_params )
275+ if write_header :
276+ writer .writeheader ()
277+ for row in last_results :
278+ writer .writerow (row ._asdict ())
279+ elif output_format == "json" :
280+ with open (local_path , "w" ) as file :
281+ file .write (json .dumps ([row ._asdict () for row in last_results ]))
282+ elif output_format == "jsonl" :
283+ with open (local_path , "w" ) as file :
284+ for row in last_results :
285+ file .write (json .dumps (row ._asdict ()))
286+ file .write ("\n " )
287+ elif output_format == "parquet" :
288+ self ._write_parquet (local_path , field_names , last_results )
289+ elif output_format == "avro" :
290+ self ._write_avro (local_path , field_names , last_results )
291+ else :
292+ raise ValueError (f"Unsupported output format: '{ self ._output_format } '" )
293+
294+ if self ._is_gcs_output :
295+ self ._upload_to_gcs (local_path , self ._output_path )
296+ finally :
297+ if self ._is_gcs_output and os .path .exists (local_path ):
298+ os .unlink (local_path )
299+
165300 return list (zip (descriptions , results ))
166301
167302
0 commit comments