diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a776377c3..3c0e325fe 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -51,12 +51,20 @@ def test_sea_async_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 5000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query with cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch enabled" ) @@ -69,8 +77,25 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" + "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly" ) # Close resources @@ -130,12 +155,20 @@ def test_sea_async_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 100)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + f"Executing asynchronous query without cloud fetch to generate {requested_row_count} rows" ) - cursor.execute_async("SELECT 1 as test_value") + cursor.execute_async(query) logger.info( "Asynchronous query submitted successfully with cloud fetch disabled" ) @@ -148,8 +181,24 @@ def test_sea_async_query_without_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. 
Expected {requested_row_count}, got {actual_row_count}" + ) + return False + logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" + "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly" ) # Close resources diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 07be8aafc..76941e2d2 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -49,13 +49,27 @@ def test_sea_sync_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 10000 cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute(query) + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + actual_row_count = len(results) logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + f"{actual_row_count} rows retrieved against {requested_row_count} requested" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") # Close resources cursor.close() @@ -114,13 +128,18 @@ def test_sea_sync_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows") + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") # Close resources cursor.close() diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 0c0400ae2..814859a31 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -251,7 +251,7 @@ def close_session(self, session_id: SessionId) -> None: logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) if session_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA session ID") + raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() request_data = DeleteSessionRequest( @@ -290,7 +290,7 @@ def get_allowed_session_configurations() -> List[str]: def _extract_description_from_manifest( self, 
manifest: ResultManifest - ) -> Optional[List]: + ) -> List[Tuple]: """ Extract column description from a manifest object, in the format defined by the spec: https://peps.python.org/pep-0249/#description @@ -299,15 +299,12 @@ def _extract_description_from_manifest( manifest: The ResultManifest object containing schema information Returns: - Optional[List]: A list of column tuples or None if no columns are found + List[Tuple]: A list of column tuples """ schema_data = manifest.schema columns_data = schema_data.get("columns", []) - if not columns_data: - return None - columns = [] for col_data in columns_data: # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) @@ -323,7 +320,7 @@ def _extract_description_from_manifest( ) ) - return columns if columns else None + return columns def _results_message_to_execute_response( self, response: GetStatementResponse @@ -429,7 +426,7 @@ def execute_command( """ if session_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA session ID") + raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -508,9 +505,11 @@ def cancel_command(self, command_id: CommandId) -> None: """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ValueError("Not a valid SEA command ID") request = CancelStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -531,9 +530,11 @@ def close_command(self, command_id: CommandId) -> None: """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ValueError("Not a valid SEA command ID") request = CloseStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -560,6 +561,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState: raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ValueError("Not a valid SEA command ID") request = GetStatementRequest(statement_id=sea_statement_id) response_data = self.http_client._make_request( @@ -592,9 +595,11 @@ def get_execution_result( """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ValueError("Not a valid SEA command ID") # Create the request model request = GetStatementRequest(statement_id=sea_statement_id) @@ -608,7 +613,7 @@ def get_execution_result( response = GetStatementResponse.from_dict(response_data) # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet execute_response = self._results_message_to_execute_response(response) @@ -616,10 +621,10 @@ def get_execution_result( connection=cursor.connection, execute_response=execute_response, sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, result_data=response.result, manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == 
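Aside — an illustrative sketch only, not part of the patch above: because _extract_description_from_manifest now always returns a (possibly empty) list instead of Optional[List], callers can iterate the PEP 249 description without a None guard. The column names below are hypothetical; the tuple layout follows the (name, type_code, display_size, internal_size, precision, scale, null_ok) format built in the method above.

    # Hypothetical description, in the seven-element PEP 249 layout used above
    description = [
        ("id", "long", None, None, None, None, True),
        ("test_value", "string", None, None, None, None, True),
    ]

    # An empty manifest now yields [] rather than None, so no None check is needed
    for name, type_code, *_ in description:
        print(f"{name}: {type_code}")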
diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
new file mode 100644
index 000000000..73f47ea96
--- /dev/null
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from abc import ABC
+from typing import List, Optional, Tuple
+
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
+from databricks.sql.backend.sea.utils.constants import ResultFormat
+from databricks.sql.exc import ProgrammingError
+from databricks.sql.utils import ResultSetQueue
+
+
+class SeaResultSetQueueFactory(ABC):
+    @staticmethod
+    def build_queue(
+        sea_result_data: ResultData,
+        manifest: ResultManifest,
+        statement_id: str,
+        description: List[Tuple] = [],
+        max_download_threads: Optional[int] = None,
+        sea_client: Optional[SeaDatabricksClient] = None,
+        lz4_compressed: bool = False,
+    ) -> ResultSetQueue:
+        """
+        Factory method to build a result set queue for SEA backend.
+
+        Args:
+            sea_result_data (ResultData): Result data from SEA response
+            manifest (ResultManifest): Manifest from SEA response
+            statement_id (str): Statement ID for the query
+            description (List[Tuple]): Column descriptions
+            max_download_threads (Optional[int]): Maximum number of download threads
+            sea_client (Optional[SeaDatabricksClient]): SEA client for fetching additional links
+            lz4_compressed (bool): Whether the data is LZ4 compressed
+
+        Returns:
+            ResultSetQueue: The appropriate queue for the result data
+        """
+
+        if manifest.format == ResultFormat.JSON_ARRAY.value:
+            # INLINE disposition with JSON_ARRAY format
+            return JsonQueue(sea_result_data.data)
+        elif manifest.format == ResultFormat.ARROW_STREAM.value:
+            # EXTERNAL_LINKS disposition
+            raise NotImplementedError(
+                "EXTERNAL_LINKS disposition is not implemented for SEA backend"
+            )
+        raise ProgrammingError("Invalid result format")
+
+
+class JsonQueue(ResultSetQueue):
+    """Queue implementation for JSON_ARRAY format data."""
+
+    def __init__(self, data_array: Optional[List[List[str]]]):
+        """Initialize with JSON array data."""
+        self.data_array = data_array or []
+        self.cur_row_index = 0
+        self.num_rows = len(self.data_array)
+
+    def next_n_rows(self, num_rows: int) -> List[List[str]]:
+        """Get the next n rows from the data array."""
+        length = min(num_rows, self.num_rows - self.cur_row_index)
+        slice = self.data_array[self.cur_row_index : self.cur_row_index + length]
+        self.cur_row_index += length
+        return slice
+
+    def remaining_rows(self) -> List[List[str]]:
+        """Get all remaining rows from the data array."""
+        slice = self.data_array[self.cur_row_index :]
+        self.cur_row_index += len(slice)
+        return slice
diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
new file mode 100644
index 000000000..302af5e3a
--- /dev/null
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -0,0 +1,266 @@
+from __future__ import annotations
+
+from typing import Any, List, Optional, TYPE_CHECKING
+
+import logging
+
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
+from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+if TYPE_CHECKING:
+    from databricks.sql.client import Connection
+from databricks.sql.exc import ProgrammingError
+from databricks.sql.types import Row
+from
databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.result_set import ResultSet + +logger = logging.getLogger(__name__) + + +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + result_data: ResultData, + manifest: ResultManifest, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + execute_response: Response from the execute command + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + result_data: Result data from SEA response + manifest: Manifest from SEA response + """ + + self.manifest = manifest + + statement_id = execute_response.command_id.to_sea_statement_id() + if statement_id is None: + raise ValueError("Command ID is not a SEA statement ID") + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + description=execute_response.description, + max_download_threads=sea_client.max_download_threads, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) + + def _convert_json_types(self, row: List[str]) -> List[Any]: + """ + Convert string values in the row to appropriate Python types based on column metadata. + """ + + # JSON + INLINE gives us string values, so we convert them to appropriate + # types based on column metadata + converted_row = [] + + for i, value in enumerate(row): + column_type = self.description[i][1] + precision = self.description[i][4] + scale = self.description[i][5] + + try: + converted_value = SqlTypeConverter.convert_value( + value, column_type, precision=precision, scale=scale + ) + converted_row.append(converted_value) + except Exception as e: + logger.warning( + f"Error converting value '{value}' to {column_type}: {e}" + ) + converted_row.append(value) + + return converted_row + + def _convert_json_to_arrow_table(self, rows: List[List[str]]) -> "pyarrow.Table": + """ + Convert raw data rows to Arrow table. + + Args: + rows: List of raw data rows + + Returns: + PyArrow Table containing the converted values + """ + + if not rows: + return pyarrow.Table.from_pydict({}) + + # create a generator for row conversion + converted_rows_iter = (self._convert_json_types(row) for row in rows) + cols = list(map(list, zip(*converted_rows_iter))) + + names = [col[0] for col in self.description] + return pyarrow.Table.from_arrays(cols, names=names) + + def _create_json_table(self, rows: List[List[str]]) -> List[Row]: + """ + Convert raw data rows to Row objects with named columns based on description. 
+
+        Args:
+            rows: List of raw data rows
+        Returns:
+            List of Row objects with named columns and converted values
+        """
+
+        ResultRow = Row(*[col[0] for col in self.description])
+        return [ResultRow(*self._convert_json_types(row)) for row in rows]
+
+    def fetchmany_json(self, size: int) -> List[List[str]]:
+        """
+        Fetch the next set of rows as a list of raw data rows.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            List of raw data rows
+
+        Raises:
+            ValueError: If size is negative
+        """
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchall_json(self) -> List[List[str]]:
+        """
+        Fetch all remaining rows as a list of raw data rows.
+
+        Returns:
+            List of all remaining raw data rows
+        """
+
+        results = self.results.remaining_rows()
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
+        """
+        Fetch the next set of rows as an Arrow table.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            PyArrow Table containing the fetched rows
+
+        Raises:
+            ImportError: If PyArrow is not installed
+            ValueError: If size is negative
+        """
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        if not isinstance(self.results, JsonQueue):
+            raise NotImplementedError("fetchmany_arrow only supported for JSON data")
+
+        results = self._convert_json_to_arrow_table(self.results.next_n_rows(size))
+        self._next_row_index += results.num_rows
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
+        """
+        Fetch all remaining rows as an Arrow table.
+        """
+
+        if not isinstance(self.results, JsonQueue):
+            raise NotImplementedError("fetchall_arrow only supported for JSON data")
+
+        results = self._convert_json_to_arrow_table(self.results.remaining_rows())
+        self._next_row_index += results.num_rows
+
+        return results
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+
+        Returns:
+            A single Row object or None if no more rows are available
+        """
+
+        if isinstance(self.results, JsonQueue):
+            res = self._create_json_table(self.fetchmany_json(1))
+        else:
+            raise NotImplementedError("fetchone only supported for JSON data")
+
+        return res[0] if res else None
+
+    def fetchmany(self, size: int) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            List of Row objects
+
+        Raises:
+            ValueError: If size is negative
+        """
+
+        if isinstance(self.results, JsonQueue):
+            return self._create_json_table(self.fetchmany_json(size))
+        else:
+            raise NotImplementedError("fetchmany only supported for JSON data")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all remaining rows of a query result, returning them as a list of rows.
+ + Returns: + List of Row objects containing all remaining rows + """ + + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchall_json()) + else: + raise NotImplementedError("fetchall only supported for JSON data") diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py new file mode 100644 index 000000000..b2de97f5d --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -0,0 +1,160 @@ +""" +Type conversion utilities for the Databricks SQL Connector. + +This module provides functionality to convert string values from SEA Inline results +to appropriate Python types based on column metadata. +""" + +import datetime +import decimal +import logging +from dateutil import parser +from typing import Callable, Dict, Optional + +logger = logging.getLogger(__name__) + + +def _convert_decimal( + value: str, precision: Optional[int] = None, scale: Optional[int] = None +) -> decimal.Decimal: + """ + Convert a string value to a decimal with optional precision and scale. + + Args: + value: The string value to convert + precision: Optional precision (total number of significant digits) for the decimal + scale: Optional scale (number of decimal places) for the decimal + + Returns: + A decimal.Decimal object with appropriate precision and scale + """ + + # First create the decimal from the string value + result = decimal.Decimal(value) + + # Apply scale (quantize to specific number of decimal places) if specified + quantizer = None + if scale is not None: + quantizer = decimal.Decimal(f'0.{"0" * scale}') + + # Apply precision (total number of significant digits) if specified + context = None + if precision is not None: + context = decimal.Context(prec=precision) + + if quantizer is not None: + result = result.quantize(quantizer, context=context) + + return result + + +class SqlType: + """ + SQL type constants + + The list of types can be found in the SEA REST API Reference: + https://docs.databricks.com/api/workspace/statementexecution/executestatement + """ + + # Numeric types + BYTE = "byte" + SHORT = "short" + INT = "int" + LONG = "long" + FLOAT = "float" + DOUBLE = "double" + DECIMAL = "decimal" + + # Boolean type + BOOLEAN = "boolean" + + # Date/Time types + DATE = "date" + TIMESTAMP = "timestamp" + INTERVAL = "interval" + + # String types + CHAR = "char" + STRING = "string" + + # Binary type + BINARY = "binary" + + # Complex types + ARRAY = "array" + MAP = "map" + STRUCT = "struct" + + # Other types + NULL = "null" + USER_DEFINED_TYPE = "user_defined_type" + + +class SqlTypeConverter: + """ + Utility class for converting SQL types to Python types. + Based on the types supported by the Databricks SDK. 
+ """ + + # SQL type to conversion function mapping + # TODO: complex types + TYPE_MAPPING: Dict[str, Callable] = { + # Numeric types + SqlType.BYTE: lambda v: int(v), + SqlType.SHORT: lambda v: int(v), + SqlType.INT: lambda v: int(v), + SqlType.LONG: lambda v: int(v), + SqlType.FLOAT: lambda v: float(v), + SqlType.DOUBLE: lambda v: float(v), + SqlType.DECIMAL: _convert_decimal, + # Boolean type + SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), + # Date/Time types + SqlType.DATE: lambda v: datetime.date.fromisoformat(v), + SqlType.TIMESTAMP: lambda v: parser.parse(v), + SqlType.INTERVAL: lambda v: v, # Keep as string for now + # String types - no conversion needed + SqlType.CHAR: lambda v: v, + SqlType.STRING: lambda v: v, + # Binary type + SqlType.BINARY: lambda v: bytes.fromhex(v), + # Other types + SqlType.NULL: lambda v: None, + # Complex types and user-defined types return as-is + SqlType.USER_DEFINED_TYPE: lambda v: v, + } + + @staticmethod + def convert_value( + value: str, + sql_type: str, + **kwargs, + ) -> object: + """ + Convert a string value to the appropriate Python type based on SQL type. + + Args: + value: The string value to convert + sql_type: The SQL type (e.g., 'int', 'decimal') + **kwargs: Additional keyword arguments for the conversion function + + Returns: + The converted value in the appropriate Python type + """ + + sql_type = sql_type.lower().strip() + + if sql_type not in SqlTypeConverter.TYPE_MAPPING: + return value + + converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] + try: + if sql_type == SqlType.DECIMAL: + precision = kwargs.get("precision", None) + scale = kwargs.get("scale", None) + return converter_func(value, precision, scale) + else: + return converter_func(value) + except (ValueError, TypeError, decimal.InvalidOperation) as e: + logger.warning(f"Error converting value '{value}' to {sql_type}: {e}") + return value diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 1b7660829..ef6c91d7d 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -17,7 +17,7 @@ ) if TYPE_CHECKING: - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.types import ExecuteResponse @@ -70,16 +70,20 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) from databricks.sql.backend.sea.backend import SeaDatabricksClient - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data + manifest = result_set.manifest + manifest.total_row_count = len(filtered_rows) + filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), + result_data=result_data, + manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, - result_data=result_data, ) return filtered_result_set diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..02d335aa4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -40,11 +40,11 @@ ) from databricks.sql.utils import ( - ResultSetQueueFactory, + ThriftResultSetQueueFactory, _bound, RequestErrorInfo, 
NoRetryReason, - ResultSetQueueFactory, + ThriftResultSetQueueFactory, convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, @@ -1232,7 +1232,7 @@ def fetch_results( ) ) - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..5411af74f 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -369,7 +369,7 @@ def from_sea_statement_id(cls, statement_id: str): return cls(BackendType.SEA, statement_id) - def to_thrift_handle(self): + def to_thrift_handle(self) -> Optional[ttypes.TOperationHandle]: """ Convert this CommandId to a Thrift TOperationHandle. @@ -390,7 +390,7 @@ def to_thrift_handle(self): modifiedRowCount=self.modified_row_count, ) - def to_sea_statement_id(self): + def to_sea_statement_id(self) -> Optional[str]: """ Get the SEA statement ID string. @@ -401,7 +401,7 @@ def to_sea_statement_id(self): if self.backend_type != BackendType.SEA: return None - return self.guid + return str(self.guid) def to_hex_guid(self) -> str: """ @@ -423,7 +423,7 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: List[Tuple] has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 38b8a3c2f..8934d0d56 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,12 +1,11 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Tuple import logging -import time import pandas -from databricks.sql.backend.sea.backend import SeaDatabricksClient - try: import pyarrow except ImportError: @@ -16,10 +15,12 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row -from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.exc import RequestError, CursorAlreadyClosedError +from databricks.sql.utils import ( + ColumnTable, + ColumnQueue, +) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def __init__( has_been_closed_server_side: bool = False, is_direct_results: bool = False, results_queue=None, - description=None, + description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, arrow_schema_bytes: Optional[bytes] = None, @@ -88,6 +89,44 @@ def __iter__(self): else: break + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+        # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
+        # NOTE: This API is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
+        dtype_mapping = {
+            pyarrow.int8(): pandas.Int8Dtype(),
+            pyarrow.int16(): pandas.Int16Dtype(),
+            pyarrow.int32(): pandas.Int32Dtype(),
+            pyarrow.int64(): pandas.Int64Dtype(),
+            pyarrow.uint8(): pandas.UInt8Dtype(),
+            pyarrow.uint16(): pandas.UInt16Dtype(),
+            pyarrow.uint32(): pandas.UInt32Dtype(),
+            pyarrow.uint64(): pandas.UInt64Dtype(),
+            pyarrow.bool_(): pandas.BooleanDtype(),
+            pyarrow.float32(): pandas.Float32Dtype(),
+            pyarrow.float64(): pandas.Float64Dtype(),
+            pyarrow.string(): pandas.StringDtype(),
+        }
+
+        # Need to rename columns, as the to_pandas function cannot handle duplicate column names
+        table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
+        df = table_renamed.to_pandas(
+            types_mapper=dtype_mapping.get,
+            date_as_object=True,
+            timestamp_as_object=True,
+        )
+
+        res = df.to_numpy(na_value=None, dtype="object")
+        return [ResultRow(*v) for v in res]
+
     @property
     def rownumber(self):
         return self._next_row_index
@@ -97,12 +136,6 @@ def is_staging_operation(self) -> bool:
         """Whether this result set represents a staging operation."""
         return self._is_staging_operation
 
-    # Define abstract methods that concrete implementations must implement
-    @abstractmethod
-    def _fill_results_buffer(self):
-        """Fill the results buffer from the backend."""
-        pass
-
     @abstractmethod
     def fetchone(self) -> Optional[Row]:
         """Fetch the next row of a query result set."""
@@ -189,10 +222,10 @@ def __init__(
         # Build the results queue if t_row_set is provided
         results_queue = None
         if t_row_set and execute_response.result_format is not None:
-            from databricks.sql.utils import ResultSetQueueFactory
+            from databricks.sql.utils import ThriftResultSetQueueFactory
 
             # Create the results queue using the provided format
-            results_queue = ResultSetQueueFactory.build_queue(
+            results_queue = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=execute_response.result_format,
                 t_row_set=t_row_set,
                 arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
@@ -249,44 +282,6 @@ def _convert_columnar_table(self, table):
 
         return result
 
-    def _convert_arrow_table(self, table):
-        column_names = [c[0] for c in self.description]
-        ResultRow = Row(*column_names)
-
-        if self.connection.disable_pandas is True:
-            return [
-                ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())
-            ]
-
-        # Need to use nullable types, as otherwise type can change when there are missing values.
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result @@ -444,82 +439,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - result_data=None, - manifest=None, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - result_data: Result data from SEA response (optional) - manifest: Manifest from SEA response (optional) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError( - "_fill_results_buffer is not implemented for SEA backend" - ) - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. 
- """ - - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 8997bda22..35c7bce4d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,12 +13,16 @@ import lz4.frame +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + try: import pyarrow except ImportError: pyarrow = None -from databricks.sql import OperationalError, exc +from databricks.sql import OperationalError +from databricks.sql.exc import ProgrammingError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -48,7 +52,7 @@ def remaining_rows(self): pass -class ResultSetQueueFactory(ABC): +class ThriftResultSetQueueFactory(ABC): @staticmethod def build_queue( row_set_type: TSparkRowSetType, @@ -57,7 +61,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +210,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ): """ A queue-like wrapper over CloudFetch arrow batches. 
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index a3f9b1af8..5848d780b 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -196,10 +196,21 @@ def test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} - def test_execute_async__small_result(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(small_result_query) ## Fake sleep for 5 secs @@ -328,8 +339,19 @@ def test_incorrect_query_throws_exception(self): cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") assert "table_234234234" in str(cm.value) - def test_create_table_will_return_empty_result_set(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_create_table_will_return_empty_result_set(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute( @@ -527,10 +549,21 @@ def test_get_catalogs(self): ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") - def test_get_arrow(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_get_arrow(self, extra_params): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM range(10)") table_1 = cursor.fetchmany_arrow(1).to_pydict() assert table_1 == OrderedDict([("id", [0])]) @@ -538,9 +571,20 @@ def test_get_arrow(self): table_2 = cursor.fetchall_arrow().to_pydict() assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) - def test_unicode(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_unicode(self, extra_params): unicode_str = "数据砖" - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT '{}'".format(unicode_str)) results = cursor.fetchall() assert len(results) == 1 and len(results[0]) == 1 @@ -578,8 +622,19 @@ def execute_really_long_query(): assert len(cursor.fetchall()) == 3 @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_failure(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_can_execute_command_after_failure(self, extra_params): + with self.cursor(extra_params) as cursor: with pytest.raises(DatabaseError): cursor.execute("this is a sytnax error") @@ -589,8 +644,19 @@ def test_can_execute_command_after_failure(self): self.assertEqualRowValues(res, [[1]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_success(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + 
"extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_can_execute_command_after_success(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1;") cursor.execute("SELECT 2;") @@ -602,8 +668,19 @@ def generate_multi_row_query(self): return query @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchone(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchone(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -614,8 +691,19 @@ def test_fetchone(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchall(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchall(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -624,8 +712,19 @@ def test_fetchall(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_when_stride_fits(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchmany_when_stride_fits(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -633,8 +732,19 @@ def test_fetchmany_when_stride_fits(self): self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_in_excess(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_fetchmany_in_excess(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -642,8 +752,19 @@ def test_fetchmany_in_excess(self): self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_iterator_api(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_iterator_api(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -716,8 +837,21 @@ def test_timestamps_arrow(self): ), "timestamp {} did not match {}".format(timestamp, expected) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_multi_timestamps_arrow(self): - with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_multi_timestamps_arrow(self, extra_params): + with self.cursor( + 
{"session_configuration": {"ansi_mode": False}, **extra_params} + ) as cursor: query, expected = self.multi_query() expected = [ [self.maybe_add_timezone_to_timestamp(ts) for ts in row] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0eda7767c..5ffdea9f0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -100,6 +100,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): ) mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False + mock_execute_response.description = [] # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 975376e13..13dfac006 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -77,7 +77,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance @@ -104,7 +104,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc6768d2b..7eae8e5a8 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -196,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i ) # Test close_session with invalid ID type - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) @@ -245,7 +245,7 @@ def test_command_execution_sync( assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: mock_thrift_handle = MagicMock() mock_thrift_handle.sessionId.guid = b"guid" mock_thrift_handle.sessionId.secret = b"secret" @@ -449,7 +449,7 @@ def test_command_management( ) # Test cancel_command with invalid ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.cancel_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -463,7 +463,7 @@ def test_command_management( ) # Test close_command with invalid ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.close_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -522,7 +522,7 @@ def test_command_management( assert result.status == CommandState.SUCCEEDED # Test get_execution_result with invalid ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -621,18 +621,6 @@ def test_utility_methods(self, sea_client): assert description[1][1] == "INT" # type_code assert description[1][6] is 
False # null_ok - # Test _extract_description_from_manifest with empty columns - empty_manifest = MagicMock() - empty_manifest.schema = {"columns": []} - assert sea_client._extract_description_from_manifest(empty_manifest) is None - - # Test _extract_description_from_manifest with no columns key - no_columns_manifest = MagicMock() - no_columns_manifest.schema = {} - assert ( - sea_client._extract_description_from_manifest(no_columns_manifest) is None - ) - def test_results_message_to_execute_response_is_staging_operation(self, sea_client): """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" # Test when is_volume_operation is True @@ -755,7 +743,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): def test_get_tables(self, sea_client, sea_session_id, mock_cursor): """Test the get_tables method with various parameter combinations.""" # Mock the execute_command method - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet mock_result_set = Mock(spec=SeaResultSet) diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py new file mode 100644 index 000000000..13970c5db --- /dev/null +++ b/tests/unit/test_sea_conversion.py @@ -0,0 +1,130 @@ +""" +Tests for the conversion module in the SEA backend. + +This module contains tests for the SqlType and SqlTypeConverter classes. +""" + +import pytest +import datetime +import decimal +from unittest.mock import Mock, patch + +from databricks.sql.backend.sea.utils.conversion import SqlType, SqlTypeConverter + + +class TestSqlTypeConverter: + """Test suite for the SqlTypeConverter class.""" + + def test_convert_numeric_types(self): + """Test converting numeric types.""" + # Test integer types + assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123 + assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456 + assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789 + assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890 + + # Test floating point types + assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45 + assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90 + + # Test decimal type + decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test decimal with precision and scale + decimal_value = SqlTypeConverter.convert_value( + "123.45", SqlType.DECIMAL, precision=5, scale=2 + ) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test invalid numeric input + result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT) + assert result == "not_a_number" # Returns original value on error + + def test_convert_boolean_type(self): + """Test converting boolean types.""" + # True values + assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True + + # False values + assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False + assert 
SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False
+        assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False
+        assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False
+        assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False
+        assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False
+
+    def test_convert_datetime_types(self):
+        """Test converting datetime types."""
+        # Test date type
+        date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE)
+        assert isinstance(date_value, datetime.date)
+        assert date_value == datetime.date(2023, 1, 15)
+
+        # Test timestamp type
+        timestamp_value = SqlTypeConverter.convert_value(
+            "2023-01-15T12:30:45", SqlType.TIMESTAMP
+        )
+        assert isinstance(timestamp_value, datetime.datetime)
+        assert timestamp_value.year == 2023
+        assert timestamp_value.month == 1
+        assert timestamp_value.day == 15
+        assert timestamp_value.hour == 12
+        assert timestamp_value.minute == 30
+        assert timestamp_value.second == 45
+
+        # Test interval type (currently returns as string)
+        interval_value = SqlTypeConverter.convert_value(
+            "1 day 2 hours", SqlType.INTERVAL
+        )
+        assert interval_value == "1 day 2 hours"
+
+        # Test invalid date input
+        result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE)
+        assert result == "not_a_date"  # Returns original value on error
+
+    def test_convert_string_types(self):
+        """Test converting string types."""
+        # String types don't need conversion, they should be returned as-is
+        assert (
+            SqlTypeConverter.convert_value("test string", SqlType.STRING)
+            == "test string"
+        )
+        assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char"
+
+    def test_convert_binary_type(self):
+        """Test converting binary type."""
+        # Test valid hex string
+        binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY)
+        assert isinstance(binary_value, bytes)
+        assert binary_value == b"Hello"
+
+        # Test invalid binary input
+        result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY)
+        assert result == "not_hex"  # Returns original value on error
+
+    def test_convert_unsupported_type(self):
+        """Test converting an unsupported type."""
+        # Should return the original value
+        assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test"
+
+        # Complex types should return as-is
+        assert (
+            SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY)
+            == "complex_value"
+        )
+        assert (
+            SqlTypeConverter.convert_value("complex_value", SqlType.MAP)
+            == "complex_value"
+        )
+        assert (
+            SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT)
+            == "complex_value"
+        )
diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
new file mode 100644
index 000000000..93d3dc4d7
--- /dev/null
+++ b/tests/unit/test_sea_queue.py
@@ -0,0 +1,182 @@
+"""
+Tests for SEA-related queue classes.
+
+This module contains tests for the JsonQueue and SeaResultSetQueueFactory classes.
+""" + +import pytest +from unittest.mock import Mock, MagicMock, patch + +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.utils.constants import ResultFormat + + +class TestJsonQueue: + """Test suite for the JsonQueue class.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", 1, True], + ["value2", 2, False], + ["value3", 3, True], + ["value4", 4, False], + ["value5", 5, True], + ] + + def test_init(self, sample_data): + """Test initialization of JsonQueue.""" + queue = JsonQueue(sample_data) + assert queue.data_array == sample_data + assert queue.cur_row_index == 0 + assert queue.num_rows == len(sample_data) + + def test_next_n_rows_partial(self, sample_data): + """Test fetching a subset of rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(2) + assert result == sample_data[:2] + assert queue.cur_row_index == 2 + + def test_next_n_rows_all(self, sample_data): + """Test fetching all rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data)) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_more_than_available(self, sample_data): + """Test fetching more rows than available.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data) + 10) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_after_partial(self, sample_data): + """Test fetching rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.next_n_rows(2) # Fetch next 2 rows + assert result == sample_data[2:4] + assert queue.cur_row_index == 4 + + def test_remaining_rows_all(self, sample_data): + """Test fetching all remaining rows at once.""" + queue = JsonQueue(sample_data) + result = queue.remaining_rows() + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_remaining_rows_after_partial(self, sample_data): + """Test fetching remaining rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.remaining_rows() # Fetch remaining rows + assert result == sample_data[2:] + assert queue.cur_row_index == len(sample_data) + + def test_empty_data(self): + """Test with empty data array.""" + queue = JsonQueue([]) + assert queue.next_n_rows(10) == [] + assert queue.remaining_rows() == [] + assert queue.cur_row_index == 0 + assert queue.num_rows == 0 + + +class TestSeaResultSetQueueFactory: + """Test suite for the SeaResultSetQueueFactory class.""" + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @pytest.fixture + def mock_description(self): + """Create a mock column description.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + def _create_empty_manifest(self, format: ResultFormat): + return ResultManifest( + format=format.value, + schema={}, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, + ) + + def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): + """Test building a queue with inline JSON data.""" + # Create 
sample data for inline JSON result + data = [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ] + + # Create a ResultData object with inline data + result_data = ResultData(data=data, external_links=None, row_count=len(data)) + + # Create a manifest (not used for inline data) + manifest = self._create_empty_manifest(ResultFormat.JSON_ARRAY) + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with the correct data + assert isinstance(queue, JsonQueue) + assert queue.data_array == data + assert queue.num_rows == len(data) + + def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): + """Test building a queue with empty data.""" + # Create a ResultData object with no data + result_data = ResultData(data=[], external_links=None, row_count=0) + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + self._create_empty_manifest(ResultFormat.JSON_ARRAY), + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with empty data + assert isinstance(queue, JsonQueue) + assert queue.data_array == [] + assert queue.num_rows == 0 + + def test_build_queue_with_external_links(self, mock_sea_client, mock_description): + """Test building a queue with external links raises NotImplementedError.""" + # Create a ResultData object with external links + result_data = ResultData( + data=None, external_links=["link1", "link2"], row_count=10 + ) + + # Verify that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSetQueueFactory.build_queue( + result_data, + self._create_empty_manifest(ResultFormat.ARROW_STREAM), + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index c596dbc14..544edaf96 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -6,10 +6,13 @@ """ import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import Mock -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.backend.sea.result_set import SeaResultSet, Row +from databricks.sql.backend.sea.queue import JsonQueue +from databricks.sql.backend.sea.utils.constants import ResultFormat +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest class TestSeaResultSet: @@ -37,11 +40,65 @@ def execute_response(self): mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = None return mock_response + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ["value3", "3", "true"], + ["value4", "4", "false"], + ["value5", 
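The SeaResultSetQueueFactory tests earlier in this file reduce the factory to a single dispatch decision: inline JSON data becomes a JsonQueue, and EXTERNAL_LINKS results are rejected until cloud fetch lands for the SEA backend. A sketch of that dispatch, with the signature inferred from the calls in the tests and the JsonQueue sketch above reused (illustrative; the manifest and sea_client arguments are accepted but unused here):

    class SeaResultSetQueueFactory:
        @staticmethod
        def build_queue(result_data, manifest, statement_id, description=None, sea_client=None):
            # Sketch of the dispatch asserted above, not the library's implementation.
            if result_data.external_links:
                raise NotImplementedError(
                    "EXTERNAL_LINKS disposition is not implemented for SEA backend"
                )
            # Inline JSON results are served straight from memory.
            return JsonQueue(result_data.data or [])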
"5", "true"], + ] + + def _create_empty_manifest(self, format: ResultFormat): + """Create an empty manifest.""" + return ResultManifest( + format=format.value, + schema={}, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, + ) + + @pytest.fixture + def result_set_with_data( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Create a SeaResultSet with sample data.""" + # Create ResultData with inline data + result_data = ResultData( + data=sample_data, external_links=None, row_count=len(sample_data) + ) + + # Initialize SeaResultSet with result data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=result_data, + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = JsonQueue(sample_data) + + return result_set + + @pytest.fixture + def json_queue(self, sample_data): + """Create a JsonQueue with sample data.""" + return JsonQueue(sample_data) + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -50,6 +107,8 @@ def test_init_with_execute_response( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -69,6 +128,8 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -89,6 +150,8 @@ def test_close_when_already_closed_server_side( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -111,6 +174,8 @@ def test_close_when_connection_closed( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -123,79 +188,191 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + def test_init_with_result_data(self, result_set_with_data, sample_data): + """Test initializing SeaResultSet with result data.""" + # Verify the results queue was created correctly + assert isinstance(result_set_with_data.results, JsonQueue) + assert result_set_with_data.results.data_array == sample_data + assert result_set_with_data.results.num_rows == len(sample_data) - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() + def test_convert_json_types(self, result_set_with_data, sample_data): + """Test the 
_convert_json_types method.""" + # Call _convert_json_types + converted_row = result_set_with_data._convert_json_types(sample_data[0]) - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) + # Verify the conversion + assert converted_row[0] == "value1" # string stays as string + assert converted_row[1] == 1 # "1" converted to int + assert converted_row[2] is True # "true" converted to boolean - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() + def test_create_json_table(self, result_set_with_data, sample_data): + """Test the _create_json_table method.""" + # Call _create_json_table + result_rows = result_set_with_data._create_json_table(sample_data) - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() + # Verify the result + assert len(result_rows) == len(sample_data) + assert isinstance(result_rows[0], Row) + assert result_rows[0].col1 == "value1" + assert result_rows[0].col2 == 1 + assert result_rows[0].col3 is True - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) + def test_fetchmany_json(self, result_set_with_data): + """Test the fetchmany_json method.""" + # Test fetching a subset of rows + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 2 + + # Test fetching the next subset + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 4 + + # Test fetching more than available + result = result_set_with_data.fetchmany_json(10) + assert len(result) == 1 # Only one row left + assert result_set_with_data._next_row_index == 5 + + def test_fetchall_json(self, result_set_with_data, sample_data): + """Test the fetchall_json method.""" + # Test fetching all rows + result = result_set_with_data.fetchall_json() + assert result == sample_data + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + result = result_set_with_data.fetchall_json() + assert result == [] + assert result_set_with_data._next_row_index == len(sample_data) + + def test_fetchone(self, result_set_with_data): + """Test the fetchone method.""" + # Test fetching one row at a time + row1 = result_set_with_data.fetchone() + assert isinstance(row1, Row) + assert row1.col1 == "value1" + assert row1.col2 == 1 + assert row1.col3 is True + assert result_set_with_data._next_row_index == 1 + + row2 = result_set_with_data.fetchone() + assert isinstance(row2, Row) + assert row2.col1 == "value2" + assert row2.col2 == 2 + assert row2.col3 is False + assert result_set_with_data._next_row_index == 2 + + # Fetch the rest + result_set_with_data.fetchall() + # Test fetching when no more rows + row_none = result_set_with_data.fetchone() + assert row_none is None + + def test_fetchmany(self, result_set_with_data): + """Test the fetchmany method.""" + # Test fetching multiple rows + rows = result_set_with_data.fetchmany(2) + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert rows[1].col1 == "value2" + assert rows[1].col2 == 2 + assert rows[1].col3 is False + assert 
result_set_with_data._next_row_index == 2 + + # Test with invalid size with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - result_set.fetchall_arrow() + result_set_with_data.fetchmany(-1) + + def test_fetchall(self, result_set_with_data, sample_data): + """Test the fetchall method.""" + # Test fetching all rows + rows = result_set_with_data.fetchall() + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + rows = result_set_with_data.fetchall() + assert len(rows) == 0 + + def test_iteration(self, result_set_with_data, sample_data): + """Test iterating over the result set.""" + # Test iteration + rows = list(result_set_with_data) + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + + def test_fetchmany_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), + buffer_size_bytes=1000, + arraysize=100, + ) + def test_fetchall_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - # Test using the result set in a for loop - for row in result_set: - pass + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), + buffer_size_bytes=1000, + arraysize=100, + ) - def test_fill_results_buffer_not_implemented( + def test_is_staging_operation( self, mock_connection, mock_sea_client, execute_response ): - """Test that _fill_results_buffer raises NotImplementedError.""" + """Test the is_staging_operation property.""" + # Set is_staging_operation to True + execute_response.is_staging_operation = True + + # Create a result set result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) - with pytest.raises( - NotImplementedError, - 
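Read as a group, the fetch tests above describe a thin conversion layer on top of JsonQueue: raw JSON rows are type-converted according to the cursor description, wrapped into Row objects for attribute access, the result set advances an internal _next_row_index by however many rows it hands back, and fetchmany rejects negative sizes. A compact sketch of that path (illustrative; _convert_json_types and _create_json_table mirror the method names under test, but these free functions are not the library's implementation):

    from collections import namedtuple

    def _convert_json_types(row, description):
        # Convert each raw string using the column type from the description tuple.
        converted = []
        for value, col in zip(row, description):
            col_type = col[1]
            if col_type == "int":
                converted.append(int(value))
            elif col_type == "boolean":
                converted.append(value.lower() == "true")
            else:
                converted.append(value)  # strings pass through unchanged
        return converted

    def _create_json_table(rows, description):
        # Wrap converted rows in a namedtuple keyed by column name.
        Row = namedtuple("Row", [col[0] for col in description])
        return [Row(*_convert_json_types(row, description)) for row in rows]

    def fetchmany(queue, description, size):
        if size < 0:
            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
        return _create_json_table(queue.next_n_rows(size), description)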
match="_fill_results_buffer is not implemented for SEA backend", - ): - result_set._fill_results_buffer() + # Test the property + assert result_set.is_staging_operation is True diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..4a4295e11 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -610,7 +610,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): self.assertIn("some information about the error", str(cm.exception)) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) def test_handle_execute_response_sets_compression_in_direct_results( self, build_queue @@ -998,7 +999,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( @@ -1043,7 +1045,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(is_direct_results, has_more_rows_result) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response(