From 138c2aebab99659d1c970fa70e4a431fec78aae2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:24:22 +0000 Subject: [PATCH 001/204] [squash from exec-sea] bring over execution phase changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 30 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 99 ++- src/databricks/sql/backend/types.py | 64 +- src/databricks/sql/client.py | 1 - src/databricks/sql/result_set.py | 234 ++++-- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 7 - tests/unit/test_client.py | 22 +- tests/unit/test_fetches.py | 13 +- tests/unit/test_fetches_bench.py | 3 +- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + tests/unit/test_thrift_backend.py | 55 +- 22 files changed, 2375 insertions(+), 366 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. 
+ +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d4..e03d6f235 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a7..e145e4e58 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d3579..fc8595839 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7c33d9b2d..76aec4675 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2622b1172..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table": return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1a7950870..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,7 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Verify initial state self.assertEqual(real_result_set.has_been_closed_server_side, closed) - expected_op_state = ( + expected_status = ( CommandState.CLOSED if closed else CommandState.SUCCEEDED ) - self.assertEqual(real_result_set.op_state, expected_op_state) + 
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a64..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ 
-88,6 +88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00da..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + 
) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in 
table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e3..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = 
thrift_backend._handle_execute_response( execute_resp, Mock() ) + self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From 3e3ab94e8fa3dd02e4b05b5fc35939aef57793a2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:31:37 +0000 Subject: [PATCH 002/204] remove excess test Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +++----------------- 1 file changed, 14 insertions(+), 110 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() From 4a781653375d8f06dd7d9ad745446e49a355c680 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:33:02 +0000 Subject: [PATCH 003/204] add docstring Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..cd347d9ab 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,6 +86,33 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod From 0dac4aaf90dba50151dd7565adee270a794e8330 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:34:49 +0000 Subject: [PATCH 004/204] remvoe exec func in sea backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 360 +++------------------- 1 file changed, 35 insertions(+), 325 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
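# A minimal, hypothetical usage sketch of the execute_command contract documented in
# the DatabricksClient docstring above. `client`, `session_id`, and `cursor` are assumed
# to already exist (they are not defined in this patch); parameter names follow the
# abstract interface.
def run_sync(client, session_id, cursor):
    # Synchronous mode: blocks until the statement completes and returns a ResultSet.
    result_set = client.execute_command(
        operation="SELECT 1 AS test_value",
        session_id=session_id,
        max_rows=100,
        max_bytes=1000,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )
    return result_set.fetchall()

def run_async(client, session_id, cursor):
    # Asynchronous mode: execute_command returns None immediately; the command ID is
    # stored on the cursor and results are retrieved later via get_execution_result().
    client.execute_command(
        operation="SELECT 1 AS test_value",
        session_id=session_id,
        max_rows=100,
        max_bytes=1000,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=True,
        enforce_embedded_schema_correctness=False,
    )
    return client.get_execution_result(cursor.active_command_id, cursor)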
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
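# A minimal sketch of the synchronous polling pattern that the removed SEA
# execute_command used above: poll get_query_state until the statement reaches a
# terminal state. `client` and `command_id` are assumed; the 0.5s interval mirrors
# the removed implementation.
import time

from databricks.sql.backend.types import CommandState

def wait_until_done(client, command_id):
    state = client.get_query_state(command_id)
    while state in (CommandState.PENDING, CommandState.RUNNING):
        time.sleep(0.5)  # small delay to avoid excessive API calls
        state = client.get_query_state(command_id)
    return state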
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 1b794c7df6f5e414ef793a5da0f2b8ba19c9bc61 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:35:40 +0000 Subject: [PATCH 005/204] remove excess files Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 -------------- tests/unit/test_result_set_filter.py | 246 ----------------------- tests/unit/test_sea_result_set.py | 275 -------------------------- 3 files changed, 664 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. 
- - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. - - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. 
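# A hypothetical sketch of how the ResultSetFilter removed above was intended to be
# used: given a SeaResultSet produced by a SHOW TABLES query, keep only rows whose
# tableType column (index 5) matches. The `show_tables_result` argument is an
# assumed input, not something defined in this patch.
from databricks.sql.backend.filters import ResultSetFilter  # module deleted by this patch

def filter_show_tables(show_tables_result):
    tables_and_views = ResultSetFilter.filter_tables_by_type(
        show_tables_result, table_types=["TABLE", "VIEW"]
    )
    system_tables = ResultSetFilter.filter_by_column_values(
        show_tables_result,
        column_index=5,
        allowed_values=["SYSTEM TABLE"],
        case_sensitive=False,
    )
    return tables_and_views, system_tables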
- -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert 
"EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - 
arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index f666fd613..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } - return mock_response - - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - 
buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - execute_response=execute_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == execute_response.sea_response - - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert 
result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() From da5a6fe7511e927c511d61adb222b8a6a0da14d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:39:11 +0000 Subject: [PATCH 006/204] remove excess models Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/__init__.py | 30 ----- src/databricks/sql/backend/sea/models/base.py | 68 ----------- .../sql/backend/sea/models/requests.py | 110 +----------------- .../sql/backend/sea/models/responses.py | 95 +-------------- 4 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. 
""" -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass From 686ade4fbf8e43a053b61f27220066852682167e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:40:50 +0000 Subject: [PATCH 007/204] remove excess sea backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 755 ++++----------------------------- 1 file changed, 94 insertions(+), 661 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with 
pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", 
return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) 
as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - 
"""Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 31e6c8305154e9c6384b422be35ac17b6f851e0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:54:05 +0000 Subject: [PATCH 008/204] cleanup Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 8 +- src/databricks/sql/backend/types.py | 38 ++++---- src/databricks/sql/result_set.py | 91 ++++++++------------ 3 files changed, 65 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e03d6f235..21a6befbe 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -913,7 +913,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - return CommandState.from_thrift_state(operation_state) + state = CommandState.from_thrift_state(operation_state) + if state is None: + raise ValueError(f"Invalid operation state: {operation_state}") + return state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -1175,7 +1178,6 @@ def _handle_execute_response(self, resp, cursor): execute_response, arrow_schema_bytes, ) = self._results_message_to_execute_response(resp, final_operation_state) - execute_response.command_id = command_id return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): @@ -1237,7 +1239,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3107083fb 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -285,9 +285,6 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, - operation_type: Optional[int] = None, - has_result_set: bool = False, - modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -296,17 +293,34 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) - operation_type: The operation type (only used for Thrift) - has_result_set: Whether the command has a result set - modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret - self.operation_type = operation_type - self.has_result_set = has_result_set - self.modified_row_count = modified_row_count + + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. 
+ + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) @classmethod def from_thrift_handle(cls, operation_handle): @@ -329,9 +343,6 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, - operation_handle.operationType, - operation_handle.hasResultSet, - operation_handle.modifiedRowCount, ) @classmethod @@ -364,9 +375,6 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, - operationType=self.operation_type, - hasResultSet=self.has_result_set, - modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index fc8595839..12ee129cf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -5,6 +5,8 @@ import time import pandas +from databricks.sql.backend.sea.backend import SeaDatabricksClient + try: import pyarrow except ImportError: @@ -13,6 +15,7 @@ if TYPE_CHECKING: from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError @@ -31,21 +34,37 @@ class ResultSet(ABC): def __init__( self, - connection, - backend, + connection: "Connection", + backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, - command_id=None, - status=None, + command_id: CommandId, + status: CommandState, has_been_closed_server_side: bool = False, has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, ): - """Initialize the base ResultSet with common properties.""" + """ + A ResultSet manages the results of a single command. 
+ + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation + """ + self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -240,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -387,12 +406,11 @@ class SeaResultSet(ResultSet): def __init__( self, - connection, - sea_client, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - execute_response=None, - sea_response=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -402,56 +420,21 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + execute_response: Response from the execute command """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - 
is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 69ea23811e03705998baba569bcda259a0646de5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:56:09 +0000 Subject: [PATCH 009/204] re-introduce get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3107083fb..7a276c102 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -299,7 +299,6 @@ def __init__( self.guid = guid self.secret = secret - def __str__(self) -> str: """ Return a string representation of the CommandId. diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 12ee129cf..1fee995e5 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -59,12 +59,12 @@ def __init__( has_been_closed_server_side: Whether the command has been closed on the server has_more_rows: Whether the command has more rows results_queue: The results queue - description: column description of the results + description: column description of the results is_staging_operation: Whether the command is a staging operation """ self.connection = connection - self.backend = backend + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -400,6 +400,23 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) + @staticmethod + def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" From 66d75171991f9fcc98d541729a3127aea0d37a81 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 06:57:53 +0000 Subject: [PATCH 010/204] remove SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 72 -------------------------------- 1 file changed, 72 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 1fee995e5..eaabcc186 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -416,75 +416,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query 
execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 71feef96b3a41889a5cd9313fc81910cebd7a084 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:01:22 +0000 Subject: [PATCH 011/204] clean imports and attributes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/databricks_client.py | 1 + src/databricks/sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/result_set.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index cd347d9ab..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -88,6 +88,7 @@ def execute_command( ) -> Union["ResultSet", None]: """ Executes a SQL command or query within the specified session. + This method sends a SQL command to the server for execution and handles the response. It can operate in both synchronous and asynchronous modes. 
diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index eaabcc186..a33fc977d 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation From ae9862f90e7cf0a4949d6b1c7e04fdbae222c2d8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:05:53 +0000 Subject: [PATCH 012/204] pass CommandId to ExecResp Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 7 ++++++- src/databricks/sql/result_set.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 21a6befbe..316cf24a0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -866,9 +867,13 @@ def get_execution_result( ssl_options=self._ssl_options, ) + status = CommandState.from_thrift_state(resp.status) + if status is None: + raise ValueError(f"Invalid operation state: {resp.status}") + execute_response = ExecuteResponse( command_id=command_id, - status=resp.status, + status=status, description=description, has_more_rows=has_more_rows, results_queue=queue, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a33fc977d..a0cb73732 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From d8aa69e40438c33014e0d5afaec6a4175e64bea8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:08:04 +0000 Subject: [PATCH 013/204] remove changes in types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 57 +++++++++-------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 7a276c102..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. 
- - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -285,6 +262,9 @@ def __init__( backend_type: BackendType, guid: Any, secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, ): """ Initialize a CommandId. @@ -293,11 +273,17 @@ def __init__( backend_type: The type of backend (THRIFT or SEA) guid: The primary identifier for the command secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command """ self.backend_type = backend_type self.guid = guid self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count def __str__(self) -> str: """ @@ -332,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -342,6 +329,9 @@ def from_thrift_handle(cls, operation_handle): BackendType.THRIFT, guid_bytes, secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, ) @classmethod @@ -374,6 +364,9 @@ def to_thrift_handle(self): handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) return ttypes.TOperationHandle( operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, ) def to_sea_statement_id(self): @@ -401,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From db139bc1179bb7cab6ec6f283cdfa0646b04b01b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:09:35 +0000 Subject: [PATCH 014/204] add back essential types (ExecResponse, from_sea_state) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..958eaa289 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,27 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + + class BackendType(Enum): """ @@ -394,3 +416,18 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False \ No newline at end of file From b977b1210a5d39543b8a3734128ba820e597337f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:11:23 +0000 Subject: [PATCH 015/204] fix fetch types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 4 ++-- src/databricks/sql/result_set.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 958eaa289..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -102,7 +102,6 @@ def from_sea_state(cls, state: str) -> Optional["CommandState"]: return state_mapping.get(state, None) - class BackendType(Enum): """ Enum representing the type of backend @@ -417,6 +416,7 @@ def to_hex_guid(self) -> str: else: return str(self.guid) + @dataclass class ExecuteResponse: """Response from executing a SQL command.""" @@ -430,4 +430,4 @@ class ExecuteResponse: results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True - is_staging_operation: bool = False \ No newline at end of file + is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0cb73732..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass From da615c0db8ba2037c106b533331cf1ca1c9f49f9 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:12:45 +0000 Subject: [PATCH 016/204] excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 0da04a6f1086998927a28759fc67da4e2c8c71c6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 07:15:59 +0000 Subject: [PATCH 017/204] reduce 
diff by maintaining logs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 316cf24a0..821559ad3 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -800,7 +800,7 @@ def _results_message_to_execute_response(self, resp, operation_state): status = CommandState.from_thrift_state(operation_state) if status is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return ( ExecuteResponse( From ea9d456ee9ca47434618a079698fa166b6c8a308 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:47:54 +0000 Subject: [PATCH 018/204] fix int test types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +--- tests/e2e/common/retry_test_mixins.py | 2 +- tests/e2e/test_driver.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 821559ad3..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -867,9 +867,7 @@ def get_execution_result( ssl_options=self._ssl_options, ) - status = CommandState.from_thrift_state(resp.status) - if status is None: - raise ValueError(f"Invalid operation state: {resp.status}") + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b5d01a45d..dd509c062 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -326,7 +326,7 @@ def test_retry_abort_close_operation_on_404(self, caplog): with self.connection(extra_params={**self._retry_policy}) as conn: with conn.cursor() as curs: with patch( - "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", + "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False, ): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 22897644f..8cfed7c28 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -933,12 +933,12 @@ def test_result_set_close(self): result_set = cursor.active_result_set assert result_set is not None - initial_op_state = result_set.op_state + initial_op_state = result_set.status result_set.close() - assert result_set.op_state == CommandState.CLOSED - assert result_set.op_state != initial_op_state + assert result_set.status == CommandState.CLOSED + assert result_set.status != initial_op_state # Closing the result set again should be a no-op and not raise exceptions result_set.close() From 8985c624bcdbb7e0abfa73b7a1a2dbad15b4e1ec Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 08:55:24 +0000 Subject: [PATCH 019/204] [squashed from exec-sea] init execution func Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 ++- .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 360 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 110 ++- 
.../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 118 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 275 +++++++ tests/unit/test_session.py | 5 + 15 files changed, 2166 insertions(+), 219 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..87b62efea 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + 
cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. 
- - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. 
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..c7a4ed1b1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ 
def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + 
"operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..1c519d931 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..a4beda629 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,16 +403,96 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, 
None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, 
mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + 
sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..f666fd613 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def 
test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + 
NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d9bcdbef396433e01b298fca9a27b1bce2b1414b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:13 +0000 Subject: [PATCH 020/204] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 124 +----- .../sql/backend/databricks_client.py | 30 ++ src/databricks/sql/backend/sea/backend.py | 360 ++---------------- .../sql/backend/sea/models/__init__.py | 30 -- src/databricks/sql/backend/sea/models/base.py | 68 ---- .../sql/backend/sea/models/requests.py | 110 +----- .../sql/backend/sea/models/responses.py | 95 +---- src/databricks/sql/backend/types.py | 64 ++-- 8 files changed, 107 insertions(+), 774 deletions(-) delete mode 100644 src/databricks/sql/backend/sea/models/base.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 87b62efea..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,122 +6,34 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) - -def test_sea_query_exec(): - """ - Test executing a query using the SEA backend with result compression. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with result compression enabled and disabled, - and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - sys.exit(1) - - try: - # Test with compression enabled - logger.info("Creating connection with LZ4 compression enabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, # Enable cloud fetch to use compression - enable_query_result_lz4_compression=True, # Enable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"backend type: {type(connection.session.backend)}") - - # Execute a simple query with compression enabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query with compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression enabled") - - # Test with compression disabled - logger.info("Creating connection with LZ4 compression disabled") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, # Enable cloud fetch - enable_query_result_lz4_compression=False, # Disable LZ4 compression - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query with compression disabled - cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) - logger.info("Executing query without compression: SELECT 1 as test_value") - cursor.execute("SELECT 1 as test_value") - logger.info("Query without compression executed successfully") - cursor.close() - connection.close() - logger.info("Successfully closed SEA session with compression disabled") - - except Exception as e: - logger.error(f"Error during SEA query execution test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA query execution test with compression completed successfully") - - def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. - + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -130,33 +42,25 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", # add custom user agent - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback - logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") - if __name__ == "__main__": - # Test session management test_sea_session() - - # Test query execution with compression - test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..20b059fa7 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,6 +16,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -86,6 +88,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). 
+ + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c7a4ed1b1..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,34 +1,23 @@ import logging import re -import uuid -import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -66,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -288,222 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. 
- - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else "NONE" - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, - byte_limit=max_bytes if max_bytes > 0 else None, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, - ) - - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=cursor.connection, - sea_response=response_data, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -514,22 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -539,30 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -574,43 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - 
enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -622,33 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index b7c8bd399..c9310d367 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,49 +4,19 @@ This package contains data models for SEA API requests and responses. """ -from databricks.sql.backend.sea.models.base import ( - ServiceError, - StatementStatus, - ExternalLink, - ResultData, - ColumnInfo, - ResultManifest, -) - from databricks.sql.backend.sea.models.requests import ( - StatementParameter, - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) __all__ = [ - # Base models - "ServiceError", - "StatementStatus", - "ExternalLink", - "ResultData", - "ColumnInfo", - "ResultManifest", # Request models - "StatementParameter", - "ExecuteStatementRequest", - "GetStatementRequest", - "CancelStatementRequest", - "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models - "ExecuteStatementResponse", - "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py deleted file mode 100644 index 671f7be13..000000000 --- a/src/databricks/sql/backend/sea/models/base.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Base models for the SEA (Statement Execution API) backend. - -These models define the common structures used in SEA API requests and responses. 
-""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState - - -@dataclass -class ServiceError: - """Error information returned by the SEA API.""" - - message: str - error_code: Optional[str] = None - - -@dataclass -class StatementStatus: - """Status information for a statement execution.""" - - state: CommandState - error: Optional[ServiceError] = None - sql_state: Optional[str] = None - - -@dataclass -class ExternalLink: - """External link information for result data.""" - - external_link: str - expiration: str - chunk_index: int - - -@dataclass -class ResultData: - """Result data from a statement execution.""" - - data: Optional[List[List[Any]]] = None - external_links: Optional[List[ExternalLink]] = None - - -@dataclass -class ColumnInfo: - """Information about a column in the result set.""" - - name: str - type_name: str - type_text: str - nullable: bool = True - precision: Optional[int] = None - scale: Optional[int] = None - ordinal_position: Optional[int] = None - - -@dataclass -class ResultManifest: - """Manifest information for a result set.""" - - schema: List[ColumnInfo] - total_row_count: int - total_byte_count: int - truncated: bool = False - chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 1c519d931..7966cb502 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,111 +1,5 @@ -""" -Request models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API requests. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - - -@dataclass -class StatementParameter: - """Parameter for a SQL statement.""" - - name: str - value: Optional[str] = None - type: Optional[str] = None - - -@dataclass -class ExecuteStatementRequest: - """Request to execute a SQL statement.""" - - warehouse_id: str - statement: str - session_id: str - disposition: str = "EXTERNAL_LINKS" - format: str = "JSON_ARRAY" - wait_timeout: str = "10s" - on_wait_timeout: str = "CONTINUE" - row_limit: Optional[int] = None - byte_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - result: Dict[str, Any] = { - "warehouse_id": self.warehouse_id, - "session_id": self.session_id, - "statement": self.statement, - "disposition": self.disposition, - "format": self.format, - "wait_timeout": self.wait_timeout, - "on_wait_timeout": self.on_wait_timeout, - } - - if self.row_limit is not None and self.row_limit > 0: - result["row_limit"] = self.row_limit - - if self.byte_limit is not None and self.byte_limit > 0: - result["byte_limit"] = self.byte_limit - - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - - if self.result_compression: - result["result_compression"] = self.result_compression - - if self.parameters: - result["parameters"] = [ - { - "name": param.name, - **({"value": param.value} if param.value is not None else {}), - **({"type": param.type} if param.type is not None else {}), - } - for param in self.parameters - ] - - return result - - -@dataclass -class 
GetStatementRequest: - """Request to get information about a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CancelStatementRequest: - """Request to cancel a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} - - -@dataclass -class CloseStatementRequest: - """Request to close a statement.""" - - statement_id: str - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for JSON serialization.""" - return {"statement_id": self.statement_id} +from typing import Dict, Any, Optional +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d70459b9f..1bb54590f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,96 +1,5 @@ -""" -Response models for the SEA (Statement Execution API) backend. - -These models define the structures used in SEA API responses. -""" - -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field - -from databricks.sql.backend.types import CommandState -from databricks.sql.backend.sea.models.base import ( - StatementStatus, - ResultManifest, - ResultData, - ServiceError, -) - - -@dataclass -class ExecuteStatementResponse: - """Response from executing a SQL statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": - """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) - - -@dataclass -class GetStatementResponse: - """Response from getting information about a statement.""" - - statement_id: str - status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": - """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - return cls( - 
statement_id=data.get("statement_id", ""), - status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed - ) +from typing import Dict, Any +from dataclasses import dataclass @dataclass diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..9cd21b5e6 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, Optional, Any import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -81,28 +80,6 @@ def from_thrift_state( else: return None - @classmethod - def from_sea_state(cls, state: str) -> Optional["CommandState"]: - """ - Map SEA state string to CommandState enum. - - Args: - state: SEA state string - - Returns: - CommandState: The corresponding CommandState enum value - """ - state_mapping = { - "PENDING": cls.PENDING, - "RUNNING": cls.RUNNING, - "SUCCEEDED": cls.SUCCEEDED, - "FAILED": cls.FAILED, - "CLOSED": cls.CLOSED, - "CANCELED": cls.CANCELLED, - } - - return state_mapping.get(state, None) - class BackendType(Enum): """ @@ -308,6 +285,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +318,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None @@ -394,19 +394,3 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) - - -@dataclass -class ExecuteResponse: - """Response from executing a SQL command.""" - - command_id: CommandId - status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None - has_been_closed_server_side: bool = False - lz4_compressed: bool = True - is_staging_operation: bool = False From ee9fa1c972bad75557ac0671d5eef96c0a0cff21 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:03:59 +0000 Subject: [PATCH 021/204] remove ResultSetFilter functionality Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 143 --------------- tests/unit/test_result_set_filter.py | 246 -------------------------- 2 files changed, 389 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 7f48b6179..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. 
- -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Callable, - TYPE_CHECKING, -) - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet - -from databricks.sql.result_set import SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. - - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data - return SeaResultSet( - connection=result_set.connection, - sea_response=filtered_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. 
- - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is typically in the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert 
len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = 
copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 From 24c6152e9c2c003aa3074057c3d7d6e98d8d1916 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:06:23 +0000 Subject: [PATCH 022/204] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 39 +- tests/unit/test_sea_backend.py | 755 ++++------------------------ 2 files changed, 132 insertions(+), 662 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e6..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,26 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. 
+ Args: + state: SEA state string + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -394,3 +415,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. -""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,650 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method 
- result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = 
sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # 
and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert 
result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def 
test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def 
test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == 
mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 From 67fd1012f9496724aa05183f82d9c92f0c40f1ed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:10:48 +0000 Subject: [PATCH 023/204] remove more irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 2 - src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/result_set.py | 91 +++++++++---------- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa7..8fda71e1e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
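
The retry docstring and strategy comment in the hunk above describe how make_request bounds its retries: iterate over a fixed range of available attempts and stop early once the total elapsed time plus the next retry delay would exceed the configured duration. A minimal sketch of that pattern, using hypothetical names (call_with_retries, max_attempts, stop_after_duration, delay) rather than the connector's actual attributes or retry policy:

    import time

    def call_with_retries(request_fn, max_attempts=5, stop_after_duration=30.0, delay=1.0):
        # Retry request_fn over a bounded range of attempts, stopping early once
        # the elapsed time plus the next delay would exceed the time budget.
        start = time.monotonic()
        last_error = None
        for attempt in range(max_attempts):
            try:
                return request_fn()
            except Exception as err:  # a real client would retry only on retryable errors
                last_error = err
                elapsed = time.monotonic() - start
                if attempt == max_attempts - 1 or elapsed + delay > stop_after_duration:
                    break
                time.sleep(delay)
        raise last_error

A real implementation would typically also back off between attempts and narrow the caught exception types; the fixed delay here only keeps the sketch short.
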
@@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a4beda629..dd61408db 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -402,6 +402,33 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): + """ + Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + """ + + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ + + return [ + (column.name, map_col_type(column.datatype), None, None, None, None, None) + for column in table_schema_message.columns + ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for the SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + 
sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -413,53 +440,19 @@ def _get_schema_description(table_schema_message): execute_response: Response from the execute command (new style) sea_response: Direct SEA response (legacy style) """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, ) def _fill_results_buffer(self): From 271fcafbb04e7c5e08423b7536dac57f9595c5b6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:12:13 +0000 Subject: [PATCH 024/204] even more irrelevant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_session.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dd61408db..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from 
unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From bf26ea3e4dae441d0e82d1f55c3da36ee2282568 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 09:19:46 +0000 Subject: [PATCH 025/204] remove sea response as init option Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 103 ++++-------------------------- 1 file changed, 14 insertions(+), 89 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f666fd613..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -27,38 +27,6 @@ def mock_sea_client(self): """Create a mock SEA client.""" return Mock() - @pytest.fixture - def sea_response(self): - """Create a sample SEA response.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - @pytest.fixture def execute_response(self): """Create a sample execute response.""" @@ -72,78 +40,35 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, - sea_client=mock_sea_client, execute_response=execute_response, + sea_client=mock_sea_client, buffer_size_bytes=1000, arraysize=100, ) # Verify basic properties - assert result_set.statement_id == "test-statement-123" + assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA assert result_set.connection == mock_connection assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize 
== 100 - assert result_set._response == execute_response.sea_response + assert result_set.description == execute_response.description - def test_init_with_no_response(self, mock_connection, mock_sea_client): - """Test that initialization fails when neither response type is provided.""" - with pytest.raises(ValueError) as excinfo: - SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - assert "Either execute_response or sea_response must be provided" in str( - excinfo.value - ) - - def test_close(self, mock_connection, mock_sea_client, sea_response): + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -157,13 +82,13 @@ def test_close(self, mock_connection, mock_sea_client, sea_response): assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -178,14 +103,14 @@ def test_close_when_already_closed_server_side( assert result_set.status == CommandState.CLOSED def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -199,13 +124,13 @@ def test_close_when_connection_closed( assert result_set.status == CommandState.CLOSED def test_unimplemented_methods( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that unimplemented methods raise NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -258,13 +183,13 @@ def test_unimplemented_methods( pass def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, sea_response + self, mock_connection, mock_sea_client, execute_response ): """Test that _fill_results_buffer raises NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, + execute_response=execute_response, sea_client=mock_sea_client, - sea_response=sea_response, buffer_size_bytes=1000, arraysize=100, ) @@ -272,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From ed7cf9138e937774546fa0f3e793a6eb8768060a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:06:36 +0000 Subject: [PATCH 026/204] exec test example scripts Signed-off-by: varun-edachali-dbx --- 
examples/experimental/sea_connector_test.py | 147 ++++++++++------ examples/experimental/tests/__init__.py | 1 + .../tests/test_sea_async_query.py | 165 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 91 ++++++++++ .../experimental/tests/test_sea_session.py | 70 ++++++++ .../experimental/tests/test_sea_sync_query.py | 143 +++++++++++++++ 6 files changed, 566 insertions(+), 51 deletions(-) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..33b5af334 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,66 +1,111 @@ +""" +Main script to run all SEA connector tests. + +This script imports and runs all the individual test modules and displays +a summary of test results with visual indicators. +""" import os import sys import logging -from databricks.sql.client import Connection +import importlib.util +from typing import Dict, Callable, List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. +# Define test modules and their main test functions +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", +] + +def load_test_function(module_name: str) -> Callable: + """Load a test function from a module.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" + ) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get the main test function (assuming it starts with "test_") + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + # For sync and async query modules, we want the main function that runs both tests + if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": + return getattr(module, name) - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
+ # Fallback to the first test function found + for name in dir(module): + if name.startswith("test_") and callable(getattr(module, name)): + return getattr(module, name) - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ + raise ValueError(f"No test function found in module {module_name}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) + for module_name in TEST_MODULES: + try: + test_func = load_test_function(module_name) + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = test_func() + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + results.append((module_name, False)) - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") + return results + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent - ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) + passed = sum(1 for _, success in results if success) + total = len(results) - logger.info("SEA session test completed successfully") + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") if __name__ == "__main__": - test_sea_session() + # Check if required environment variables are set + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run 
all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) \ No newline at end of file diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..5e1a8a58b --- /dev/null +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a4f3702f9 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,165 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch enabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for asynchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute_async("SELECT 1 as test_value") + logger.info("Asynchronous query submitted successfully with cloud fetch disabled") + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..ba760b61a --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,91 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") + cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..c0f6817da --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,70 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..4879e587a --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,143 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) \ No newline at end of file From dae15e37b6161740481084c405aeff84278c73cd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:10:23 +0000 Subject: [PATCH 027/204] formatting (black) Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 53 ++++++++------ examples/experimental/tests/__init__.py | 1 - .../tests/test_sea_async_query.py | 72 +++++++++++++------ .../experimental/tests/test_sea_metadata.py | 27 ++++--- .../experimental/tests/test_sea_session.py | 5 +- .../experimental/tests/test_sea_sync_query.py | 48 +++++++++---- 6 files changed, 133 insertions(+), 73 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 33b5af334..b03f8ff64 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,90 +22,99 @@ "test_sea_metadata", ] + def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "tests", - f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") + def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback + logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results + def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 
50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") + if __name__ == "__main__": # Check if required environment variables are set - required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) \ No newline at end of file + sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index 5e1a8a58b..e69de29bb 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -1 +0,0 @@ -# This file makes the tests directory a Python package \ No newline at end of file diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a4f3702f9..a776377c3 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -33,7 +33,9 @@ def test_sea_async_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for asynchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -51,30 +53,39 @@ def test_sea_async_query_with_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch enabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -100,7 +111,9 @@ def test_sea_async_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for 
asynchronous query execution with cloud fetch disabled") + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -119,30 +132,39 @@ def test_sea_async_query_without_cloud_fetch(): # Execute a simple query asynchronously cursor = connection.cursor() - logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute_async("SELECT 1 as test_value") - logger.info("Asynchronous query submitted successfully with cloud fetch disabled") - + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + # Check query state logger.info("Checking query state...") while cursor.is_query_pending(): logger.info("Query is still pending, waiting...") time.sleep(1) - + logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled") - + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -152,14 +174,18 @@ def test_sea_async_query_exec(): Run both asynchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_async_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index ba760b61a..c715e5984 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -28,9 +28,11 @@ def test_sea_metadata(): "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." ) return False - + if not catalog: - logger.error("DATABRICKS_CATALOG environment variable is required for metadata tests.") + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." 
+ ) return False try: @@ -55,37 +57,42 @@ def test_sea_metadata(): logger.info("Fetching catalogs...") cursor.catalogs() logger.info("Successfully fetched catalogs") - + # Test schemas logger.info(f"Fetching schemas for catalog '{catalog}'...") cursor.schemas(catalog_name=catalog) logger.info("Successfully fetched schemas") - + # Test tables logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") cursor.tables(catalog_name=catalog, schema_name="default") logger.info("Successfully fetched tables") - + # Test columns for a specific table # Using a common table that should exist in most environments - logger.info(f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'...") - cursor.columns(catalog_name=catalog, schema_name="default", table_name="information_schema") + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="information_schema" + ) logger.info("Successfully fetched columns") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error during SEA metadata test: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_metadata() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py index c0f6817da..516c1bbb8 100644 --- a/examples/experimental/tests/test_sea_session.py +++ b/examples/experimental/tests/test_sea_session.py @@ -55,16 +55,17 @@ def test_sea_session(): logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) return False if __name__ == "__main__": success = test_sea_session() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 4879e587a..07be8aafc 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -31,7 +31,9 @@ def test_sea_sync_query_with_cloud_fetch(): try: # Create connection with cloud fetch enabled - logger.info("Creating connection for synchronous query execution with cloud fetch enabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -49,20 +51,25 @@ def test_sea_sync_query_with_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query with cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch enabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}") + logger.error( + f"Error during SEA 
synchronous query execution test with cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -88,7 +95,9 @@ def test_sea_sync_query_without_cloud_fetch(): try: # Create connection with cloud fetch disabled - logger.info("Creating connection for synchronous query execution with cloud fetch disabled") + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) connection = Connection( server_hostname=server_hostname, http_path=http_path, @@ -107,20 +116,25 @@ def test_sea_sync_query_without_cloud_fetch(): # Execute a simple query cursor = connection.cursor() - logger.info("Executing synchronous query without cloud fetch: SELECT 1 as test_value") + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch disabled") - + # Close resources cursor.close() connection.close() logger.info("Successfully closed SEA session") - + return True except Exception as e: - logger.error(f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}") + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) import traceback + logger.error(traceback.format_exc()) return False @@ -130,14 +144,18 @@ def test_sea_sync_query_exec(): Run both synchronous query tests and return overall success. """ with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info(f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info(f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}") - + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + return with_cloud_fetch_success and without_cloud_fetch_success if __name__ == "__main__": success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) From db5bbea88eabcde2d0b86811391297baf8471c70 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 10:35:08 +0000 Subject: [PATCH 028/204] [squashed from sea-exec] merge sea stuffs Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 51 +- examples/experimental/tests/__init__.py | 1 + .../sql/backend/databricks_client.py | 28 - src/databricks/sql/backend/filters.py | 143 ++++ src/databricks/sql/backend/sea/backend.py | 359 ++++++++- .../sql/backend/sea/models/__init__.py | 30 + src/databricks/sql/backend/sea/models/base.py | 68 ++ .../sql/backend/sea/models/requests.py | 106 ++- .../sql/backend/sea/models/responses.py | 95 ++- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/backend/types.py | 25 +- src/databricks/sql/result_set.py | 92 ++- tests/unit/test_result_set_filter.py | 246 ++++++ tests/unit/test_sea_backend.py | 755 +++++++++++++++--- tests/unit/test_sea_result_set.py | 30 +- tests/unit/test_session.py | 5 + 16 files changed, 1805 insertions(+), 232 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 src/databricks/sql/backend/sea/models/base.py create mode 100644 tests/unit/test_result_set_filter.py diff 
--git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..128bc1aa1 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,99 +22,90 @@ "test_sea_metadata", ] - def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), + "tests", + f"{module_name}.py" ) - + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") - def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback - logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results - def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") - if __name__ == "__main__": # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] + required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" - ) + logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index e69de29bb..5e1a8a58b 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -0,0 +1 @@ +# This file 
makes the tests directory a Python package \ No newline at end of file diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..bbca4c502 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. - - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..7f48b6179 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. 
+ + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..10100e86e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,221 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. 
+ + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else None + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +513,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +538,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +573,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +621,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d367..b7c8bd399 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 000000000..671f7be13 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb502..e26b32e0a 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,107 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590f..d70459b9f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..810c2e7a1 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. 
@@ -1242,7 +1241,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..5bf02e0ea 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,8 +85,10 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. + Args: state: SEA state string + Returns: CommandState: The corresponding CommandState enum value """ @@ -306,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". - - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -339,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..2d4f3f346 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( 
n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -403,14 +403,76 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + is_staging_operation=is_staging_operation, + ) + + def _fill_results_buffer(self): + """Fill the results buffer from the 
backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") return [ (column.name, map_col_type(column.datatype), None, None, None, None, None) diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 000000000..e8eb2a757 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in 
table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + 
arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..c2078f731 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..072b597a8 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -40,8 +40,36 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } return mock_response + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -197,4 +225,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8fd..fef070362 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From d5d3699cea5c5e67a48c5e789ebdd66964f1e975 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:44:58 +0000 Subject: [PATCH 029/204] remove excess changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 51 ++++++++++++--------- examples/experimental/tests/__init__.py | 1 - 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 128bc1aa1..b03f8ff64 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -22,90 +22,99 @@ "test_sea_metadata", ] + def load_test_function(module_name: str) -> Callable: """Load a test function from a module.""" module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "tests", - f"{module_name}.py" + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - 
+ spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Get the main test function (assuming it starts with "test_") for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): # For sync and async query modules, we want the main function that runs both tests if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": return getattr(module, name) - + # Fallback to the first test function found for name in dir(module): if name.startswith("test_") and callable(getattr(module, name)): return getattr(module, name) - + raise ValueError(f"No test function found in module {module_name}") + def run_tests() -> List[Tuple[str, bool]]: """Run all tests and return results.""" results = [] - + for module_name in TEST_MODULES: try: test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - + success = test_func() results.append((module_name, success)) - + status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"Test {module_name}: {status}") - + except Exception as e: logger.error(f"Error loading or running test {module_name}: {str(e)}") import traceback + logger.error(traceback.format_exc()) results.append((module_name, False)) - + return results + def print_summary(results: List[Tuple[str, bool]]) -> None: """Print a summary of test results.""" logger.info(f"\n{'=' * 50}") logger.info("TEST SUMMARY") logger.info(f"{'-' * 50}") - + passed = sum(1 for _, success in results if success) total = len(results) - + for module_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" logger.info(f"{status} - {module_name}") - + logger.info(f"{'-' * 50}") logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") + if __name__ == "__main__": # Check if required environment variables are set - required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"] + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] missing_vars = [var for var in required_vars if not os.environ.get(var)] - + if missing_vars: - logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) logger.error("Please set these variables before running the tests.") sys.exit(1) - + # Run all tests results = run_tests() - + # Print summary print_summary(results) - + # Exit with appropriate status code all_passed = all(success for _, success in results) sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py index 5e1a8a58b..e69de29bb 100644 --- a/examples/experimental/tests/__init__.py +++ b/examples/experimental/tests/__init__.py @@ -1 +0,0 @@ -# This file makes the tests directory a Python package \ No newline at end of file From 6137a3dca8ea8d0c2105a175b99f45e77fa25f5b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:47:07 +0000 Subject: [PATCH 030/204] remove excess removed docstring Signed-off-by: varun-edachali-dbx --- .../sql/backend/databricks_client.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index bbca4c502..8fda71e1e 100644 --- 
a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -86,6 +86,34 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns a ResultSet object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ pass @abstractmethod From 75b077320c196104e47af149b379ebc4e95463e3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:48:33 +0000 Subject: [PATCH 031/204] remove excess changes in backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 3 ++- src/databricks/sql/backend/types.py | 25 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 810c2e7a1..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ + # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(command_id.guid)) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 5bf02e0ea..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -85,10 +85,8 @@ def from_thrift_state( def from_sea_state(cls, state: str) -> Optional["CommandState"]: """ Map SEA state string to CommandState enum. 
- Args: state: SEA state string - Returns: CommandState: The corresponding CommandState enum value """ @@ -308,6 +306,28 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -319,6 +339,7 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ + if operation_handle is None: return None From 4494dcda4a503e6138e5761bc6155114d840be86 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:50:56 +0000 Subject: [PATCH 032/204] remove excess imports Signed-off-by: varun-edachali-dbx --- tests/unit/test_session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fef070362..92de8d8fd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,11 +2,6 @@ from unittest.mock import patch, Mock import gc -from databricks.sql.thrift_api.TCLIService.ttypes import ( - TOpenSessionResp, - TSessionHandle, - THandleIdentifier, -) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql From 4d0aeca0a2e9d887274cbdbd19c6f471f1a381a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:53:52 +0000 Subject: [PATCH 033/204] remove accidentally removed _get_schema_desc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 74 +++----------------------------- 1 file changed, 6 insertions(+), 68 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 2d4f3f346..e0b0289e6 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -403,76 +403,14 @@ def fetchmany(self, size: int) -> List[Row]: @staticmethod def _get_schema_description(table_schema_message): """ - Initialize a SeaResultSet with the response from a SEA query execution. 
- - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - # Handle both initialization styles - if execute_response is not None: - # New style with ExecuteResponse - command_id = execute_response.command_id - status = execute_response.status - has_been_closed_server_side = execute_response.has_been_closed_server_side - has_more_rows = execute_response.has_more_rows - results_queue = execute_response.results_queue - description = execute_response.description - is_staging_operation = execute_response.is_staging_operation - self._response = getattr(execute_response, "sea_response", {}) - self.statement_id = command_id.to_sea_statement_id() if command_id else None - elif sea_response is not None: - # Legacy style with direct sea_response - self._response = sea_response - # Extract values from sea_response - command_id = CommandId.from_sea_statement_id( - sea_response.get("statement_id", "") - ) - self.statement_id = sea_response.get("statement_id", "") - - # Extract status - status_data = sea_response.get("status", {}) - status = CommandState.from_sea_state(status_data.get("state", "PENDING")) - - # Set defaults for other fields - has_been_closed_server_side = False - has_more_rows = False - results_queue = None - description = None - is_staging_operation = False - else: - raise ValueError("Either execute_response or sea_response must be provided") - - # Call parent constructor with common attributes - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=command_id, - status=status, - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - results_queue=results_queue, - description=description, - is_staging_operation=is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. 
+ Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 """ - raise NotImplementedError("fetchone is not implemented for SEA backend") + def map_col_type(type_): + if type_.startswith("decimal"): + return "decimal" + else: + return type_ return [ (column.name, map_col_type(column.datatype), None, None, None, None, None) From 7cece5e0870cd31943e72c86888d98ed4e09c17c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:56:24 +0000 Subject: [PATCH 034/204] remove unnecessary init with sea_response tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 072b597a8..02421a915 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -40,36 +40,8 @@ def execute_response(self): ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.sea_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "result": {"data_array": [["1"]]}, - } return mock_response - def test_init_with_sea_response( - self, mock_connection, mock_sea_client, sea_response - ): - """Test initializing SeaResultSet with a SEA response.""" - result_set = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.statement_id == "test-statement-123" - assert result_set.status == CommandState.SUCCEEDED - assert result_set.command_id.guid == "test-statement-123" - assert result_set.command_id.backend_type == BackendType.SEA - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set._response == sea_response - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -225,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() + result_set._fill_results_buffer() \ No newline at end of file From 8977c06a27a68ae7c144a482e32c7bee1e18eaa3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 12:57:58 +0000 Subject: [PATCH 035/204] rmeove unnecessary changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e0b0289e6..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -64,7 +64,7 @@ def __init__( """ self.connection = connection - self.backend = backend # Store the backend client directly + self.backend = backend self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 @@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> Any: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> Any: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" 
pass @@ -207,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self._has_more_rows = has_more_rows + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -259,7 +259,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2): + def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result :param result1: @@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self._has_more_rows + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self._has_more_rows: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) From 0216d7ac6de96ece431f8bdd0d31c0acb1c28324 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:07:04 +0000 Subject: [PATCH 036/204] formatting (black) Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..b691872af 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -197,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() From d97463b45fd6c8e7457988441edc012e51d78368 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:21:34 +0000 Subject: [PATCH 037/204] move guid_to_hex_id import to utils Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..f90d2897e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -10,15 +10,13 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, 
SessionId, CommandId, - BackendType, - guid_to_hex_id, ExecuteResponse, ) +from databricks.sql.backend.utils.guid_utils import guid_to_hex_id try: import pyarrow From 139e2466ef9c35a2673e4af6066549004cf16533 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:22:25 +0000 Subject: [PATCH 038/204] reduce diff in guid utils import Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f90d2897e..4b3e827f2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -16,7 +16,8 @@ CommandId, ExecuteResponse, ) -from databricks.sql.backend.utils.guid_utils import guid_to_hex_id +from databricks.sql.backend.utils import guid_to_hex_id + try: import pyarrow From 4cb15fdaa8318b046f2ac082edb10679e7c7a501 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 13:47:34 +0000 Subject: [PATCH 039/204] improved models and filters from cloudfetch-sea branch Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 61 +++++--- src/databricks/sql/backend/sea/models/base.py | 13 +- .../sql/backend/sea/models/requests.py | 16 +- .../sql/backend/sea/models/responses.py | 146 ++++++++++++++++-- 4 files changed, 187 insertions(+), 49 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 7f48b6179..32fa78be4 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -9,14 +9,20 @@ List, Optional, Any, + Dict, Callable, + TypeVar, + Generic, + cast, TYPE_CHECKING, ) -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse, CommandId +from databricks.sql.backend.sea.models.base import ResultData -from databricks.sql.result_set import SeaResultSet +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet, SeaResultSet logger = logging.getLogger(__name__) @@ -43,26 +49,35 @@ def _filter_sea_result_set( Returns: A filtered SEA result set """ - # Create a filtered version of the result set - filtered_response = result_set._response.copy() - - # If there's a result with rows, filter them - if ( - "result" in filtered_response - and "data_array" in filtered_response["result"] - ): - rows = filtered_response["result"]["data_array"] - filtered_rows = [row for row in rows if filter_func(row)] - filtered_response["result"]["data_array"] = filtered_rows - - # Update row count if present - if "row_count" in filtered_response["result"]: - filtered_response["result"]["row_count"] = len(filtered_rows) - - # Create a new result set with the filtered data + # Get all remaining rows + original_index = result_set.results.cur_row_index + result_set.results.cur_row_index = 0 # Reset to beginning + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Import SeaResultSet here to avoid circular imports + from databricks.sql.result_set import SeaResultSet + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse with the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + 
description=result_set.description, + has_more_rows=result_set._has_more_rows, + results_queue=JsonQueue(filtered_rows), + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=False, + is_staging_operation=False, + ) + return SeaResultSet( connection=result_set.connection, - sea_response=filtered_response, + execute_response=execute_response, sea_client=result_set.backend, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, @@ -92,6 +107,8 @@ def filter_by_column_values( allowed_values = [v.upper() for v in allowed_values] # Determine the type of result set and apply appropriate filtering + from databricks.sql.result_set import SeaResultSet + if isinstance(result_set, SeaResultSet): return ResultSetFilter._filter_sea_result_set( result_set, @@ -137,7 +154,7 @@ def filter_tables_by_type( table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES ) - # Table type is typically in the 6th column (index 5) + # Table type is the 6th column (index 5) return ResultSetFilter.filter_by_column_values( result_set, 5, valid_types, case_sensitive=False ) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 671f7be13..6175b4ca0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -34,6 +34,12 @@ class ExternalLink: external_link: str expiration: str chunk_index: int + byte_count: int = 0 + row_count: int = 0 + row_offset: int = 0 + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + http_headers: Optional[Dict[str, str]] = None @dataclass @@ -61,8 +67,11 @@ class ColumnInfo: class ResultManifest: """Manifest information for a result set.""" - schema: List[ColumnInfo] + format: str + schema: Dict[str, Any] # Will contain column information total_row_count: int total_byte_count: int + total_chunk_count: int truncated: bool = False - chunk_count: Optional[int] = None + chunks: Optional[List[Dict[str, Any]]] = None + result_compression: Optional[str] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index e26b32e0a..58921d793 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -21,18 +21,16 @@ class StatementParameter: class ExecuteStatementRequest: """Request to execute a SQL statement.""" - warehouse_id: str - statement: str session_id: str + statement: str + warehouse_id: str disposition: str = "EXTERNAL_LINKS" format: str = "JSON_ARRAY" + result_compression: Optional[str] = None + parameters: Optional[List[StatementParameter]] = None wait_timeout: str = "10s" on_wait_timeout: str = "CONTINUE" row_limit: Optional[int] = None - parameters: Optional[List[StatementParameter]] = None - catalog: Optional[str] = None - schema: Optional[str] = None - result_compression: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert the request to a dictionary for JSON serialization.""" @@ -49,12 +47,6 @@ def to_dict(self) -> Dict[str, Any]: if self.row_limit is not None and self.row_limit > 0: result["row_limit"] = self.row_limit - if self.catalog: - result["catalog"] = self.catalog - - if self.schema: - result["schema"] = self.schema - if self.result_compression: result["result_compression"] = self.result_compression diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py 
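A hedged usage sketch of the reordered ExecuteStatementRequest above; the session, warehouse, and compression values are placeholders, and the only behaviour assumed is what this hunk shows (row_limit and result_compression are serialised conditionally, and catalog/schema no longer appear in the payload).

    # Sketch only: build and serialise the reordered request model.
    from databricks.sql.backend.sea.models.requests import ExecuteStatementRequest

    request = ExecuteStatementRequest(
        session_id="example-session-id",      # placeholder
        statement="SELECT 1 AS test_value",
        warehouse_id="example-warehouse-id",  # placeholder
        result_compression="LZ4_FRAME",       # assumed codec name
        row_limit=100,
    )
    payload = request.to_dict()
    # "catalog" and "schema" keys are no longer part of the payload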
index d70459b9f..6b5067506 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -13,6 +13,8 @@ ResultManifest, ResultData, ServiceError, + ExternalLink, + ColumnInfo, ) @@ -37,20 +39,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": error_code=error_data.get("error_code"), ) - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") status = StatementStatus( - state=state, + state=CommandState.from_sea_state(status_data.get("state", "")), error=error, sql_state=status_data.get("sql_state"), ) + # Parse manifest + manifest = None + if "manifest" in data: + manifest_data = data["manifest"] + manifest = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + # Parse result data + result = None + if "result" in data: + result_data = data["result"] + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + result = ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + return cls( statement_id=data.get("statement_id", ""), status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed + manifest=manifest, + result=result, ) @@ -75,21 +119,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": error_code=error_data.get("error_code"), ) - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - status = StatementStatus( - state=state, + state=CommandState.from_sea_state(status_data.get("state", "")), error=error, sql_state=status_data.get("sql_state"), ) + # Parse manifest + manifest = None + if "manifest" in data: + manifest_data = data["manifest"] + manifest = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + # Parse result data + result = None + if "result" in data: + result_data = data["result"] + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in 
result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + result = ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + return cls( statement_id=data.get("statement_id", ""), status=status, - manifest=data.get("manifest"), # We'll parse this more fully if needed - result=data.get("result"), # We'll parse this more fully if needed + manifest=manifest, + result=result, ) @@ -103,3 +188,38 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) + + +@dataclass +class GetChunksResponse: + """Response from getting chunks for a statement.""" + + statement_id: str + external_links: List[ExternalLink] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": + """Create a GetChunksResponse from a dictionary.""" + external_links = [] + if "external_links" in data: + for link_data in data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers"), + ) + ) + + return cls( + statement_id=data.get("statement_id", ""), + external_links=external_links, + ) From e3ee4e4acfd7178db6a78dadce21bc6e7a52b77f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 15:24:33 +0000 Subject: [PATCH 040/204] move arrow_schema_bytes back into ExecuteResult Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 58 ++++------ src/databricks/sql/backend/types.py | 1 + src/databricks/sql/result_set.py | 4 +- tests/unit/test_thrift_backend.py | 106 ++++++++++++++++--- 4 files changed, 116 insertions(+), 53 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 4b3e827f2..d99cf2624 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -801,18 +801,16 @@ def _results_message_to_execute_response(self, resp, operation_state): if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ( - ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - ), - schema_bytes, + return ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + 
has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, ) def get_execution_result( @@ -877,6 +875,7 @@ def get_execution_result( has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -886,7 +885,6 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -999,9 +997,7 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1010,7 +1006,6 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1032,9 +1027,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1043,7 +1036,6 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1069,9 +1061,7 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1080,7 +1070,6 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1110,9 +1099,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1121,7 +1108,6 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1151,9 +1137,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( - resp, cursor - ) + execute_response = self._handle_execute_response(resp, cursor) return ThriftResultSet( connection=cursor.connection, @@ -1162,7 +1146,6 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): @@ -1176,11 +1159,10 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, 
arrow_schema_bytes + execute_response = self._results_message_to_execute_response( + resp, final_operation_state + ) + return execute_response def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..fed1bc6cd 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -431,3 +431,4 @@ class ExecuteResponse: has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index e177d495f..23e0fa490 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -157,7 +157,6 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -169,10 +168,9 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes + self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b8de970db..dc2b9c038 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -19,7 +19,13 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ResultSet, ThriftResultSet -from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType +from databricks.sql.backend.types import ( + CommandId, + CommandState, + SessionId, + BackendType, + ExecuteResponse, +) def retry_policy_factory(): @@ -651,7 +657,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -885,7 +891,7 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -963,11 +969,11 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ 
-1040,7 +1046,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1172,7 +1178,20 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1206,7 +1225,20 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1237,7 +1269,20 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1277,7 +1322,20 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1321,7 +1379,20 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock( + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ) + ) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -2229,7 +2300,18 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): 
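These test rewrites follow from the change earlier in the patch: _handle_execute_response now returns a single ExecuteResponse that carries arrow_schema_bytes itself instead of a (response, schema_bytes) tuple. A small sketch of reading the schema back off the response; the pyarrow IPC call is illustrative and assumes the bytes hold a serialised Arrow schema.

    # Sketch only: recover the Arrow schema from the response object rather
    # than from the second element of a returned tuple.
    from typing import Optional

    import pyarrow

    from databricks.sql.backend.types import ExecuteResponse

    def arrow_schema_of(response: ExecuteResponse) -> Optional[pyarrow.Schema]:
        if not response.arrow_schema_bytes:
            return None
        # Assumes IPC-serialised schema bytes, as produced by the Thrift backend
        return pyarrow.ipc.read_schema(pyarrow.py_buffer(response.arrow_schema_bytes))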
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + return_value=Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_more_rows=Mock(), + results_queue=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + ), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class From f448a8f18170c3acd157810b6960605362fcfbd3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 15:59:50 +0000 Subject: [PATCH 041/204] maintain log Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d99cf2624..6f05b45a5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -915,7 +915,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod From 82ca1eefc150da88e637d25f26198fc696400dbe Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 16:01:48 +0000 Subject: [PATCH 042/204] remove un-necessary assignment Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 6f05b45a5..0ff68651e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1159,10 +1159,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - execute_response = self._results_message_to_execute_response( - resp, final_operation_state - ) - return execute_response + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) From e96a0785d188171aa79121b15c722a9dfd09cccd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 16:06:03 +0000 Subject: [PATCH 043/204] remove un-necessary tuple response Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index dc2b9c038..733ea17a5 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -929,12 +929,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) - thrift_backend._results_message_to_execute_response.assert_called_with( 
execute_resp, ttypes.TOperationState.FINISHED_STATE, @@ -1738,9 +1735,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() From 27158b1fe5998e3ccaebf2c3a0cc5b462e1f656c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 9 Jun 2025 16:10:27 +0000 Subject: [PATCH 044/204] remove un-ncessary verbose mocking Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 75 +++---------------------------- 1 file changed, 5 insertions(+), 70 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 733ea17a5..c9cb05305 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1175,20 +1175,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock( - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - results_queue=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - ) - ) + thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1222,20 +1209,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock( - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - results_queue=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - ) - ) + thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1266,20 +1240,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock( - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - results_queue=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - ) - ) + thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1319,20 +1280,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock( - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - results_queue=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - ) - ) + thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1376,20 +1324,7 @@ def test_get_columns_calls_client_and_handle_execute_response( 
auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock( - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - results_queue=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - ) - ) + thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() result = thrift_backend.get_columns( From dee47f7f4558a8c7336c86bbd5a20bda3f4a9787 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 03:45:23 +0000 Subject: [PATCH 045/204] filters stuff (align with JDBC) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 18 ++++++------------ src/databricks/sql/backend/sea/backend.py | 3 --- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 32fa78be4..9fa0a5535 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -49,32 +49,26 @@ def _filter_sea_result_set( Returns: A filtered SEA result set """ - # Get all remaining rows - original_index = result_set.results.cur_row_index - result_set.results.cur_row_index = 0 # Reset to beginning + # Get all remaining rows from the current position (JDBC-aligned behavior) + # Note: This will only filter rows that haven't been read yet all_rows = result_set.results.remaining_rows() # Filter rows filtered_rows = [row for row in all_rows if filter_func(row)] - # Import SeaResultSet here to avoid circular imports - from databricks.sql.result_set import SeaResultSet - - # Reuse the command_id from the original result set - command_id = result_set.command_id - - # Create an ExecuteResponse with the filtered data execute_response = ExecuteResponse( - command_id=command_id, + command_id=result_set.command_id, status=result_set.status, description=result_set.description, - has_more_rows=result_set._has_more_rows, + has_more_rows=result_set.has_more_rows, results_queue=JsonQueue(filtered_rows), has_been_closed_server_side=result_set.has_been_closed_server_side, lz4_compressed=False, is_staging_operation=False, ) + from databricks.sql.result_set import SeaResultSet + return SeaResultSet( connection=result_set.connection, execute_response=execute_response, diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 10100e86e..a54337f0c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -66,9 +66,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. 
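The JDBC-aligned filter above only sees rows that have not been read yet, which means it is most useful before any rows have been fetched from the metadata result set. A hedged usage sketch, assuming only the class-level helpers shown in filters.py:

    # Sketch only: post-filter a SEA metadata result set the way the JDBC
    # driver does; passing table_types=None falls back to the filter's defaults.
    from databricks.sql.backend.filters import ResultSetFilter

    def tables_only(result_set):
        # TABLE_TYPE sits in column index 5; matching is case-insensitive.
        return ResultSetFilter.filter_tables_by_type(result_set, table_types=None)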
""" # SEA API paths From d3200c49d87ef32184b48877d115353d51b82dd4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 05:31:55 +0000 Subject: [PATCH 046/204] move Queue construction to ResultSert Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 81 ++++++++++++-------- src/databricks/sql/backend/types.py | 6 +- src/databricks/sql/result_set.py | 24 +++++- tests/unit/test_client.py | 9 ++- tests/unit/test_fetches.py | 40 ++++++---- tests/unit/test_thrift_backend.py | 55 ++++++++++--- 6 files changed, 148 insertions(+), 67 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 0ff68651e..2e3e61ca0 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,7 +3,6 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING @@ -728,7 +727,7 @@ def _col_to_description(col): else: precision, scale = None, None - return col.columnName, cleaned_type, None, None, precision, scale, None + return [col.columnName, cleaned_type, None, None, precision, scale, None] @staticmethod def _hive_schema_to_description(t_table_schema): @@ -778,23 +777,6 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) @@ -806,11 +788,11 @@ def _results_message_to_execute_response(self, resp, operation_state): status=status, description=description, has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) def get_execution_result( @@ -837,9 +819,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -854,15 +833,9 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = 
resp.hasMoreRows status = self.get_query_state(command_id) @@ -871,11 +844,11 @@ def get_execution_result( status=status, description=description, has_more_rows=has_more_rows, - results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -885,6 +858,9 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -999,6 +975,10 @@ def execute_command( else: execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1006,6 +986,9 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_catalogs( @@ -1029,6 +1012,10 @@ def get_catalogs( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1036,6 +1023,9 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_schemas( @@ -1063,6 +1053,10 @@ def get_schemas( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1070,6 +1064,9 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_tables( @@ -1101,6 +1098,10 @@ def get_tables( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1108,6 +1109,9 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def get_columns( @@ -1139,6 +1143,10 @@ def get_columns( execute_response = self._handle_execute_response(resp, cursor) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1146,6 +1154,9 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, 
use_cloud_fetch=cursor.connection.use_cloud_fetch, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, ) def _handle_execute_response(self, resp, cursor): @@ -1203,6 +1214,8 @@ def fetch_results( ) ) + from databricks.sql.utils import ResultSetQueueFactory + queue = ResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index fed1bc6cd..ba2975d7c 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,12 +423,10 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None + description: Optional[List[List[Any]]] = None has_more_rows: bool = False - results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 23e0fa490..ab3fb68f2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -157,6 +157,9 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -168,12 +171,31 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + t_row_set: The TRowSet containing result data (if available) + max_download_threads: Maximum number of download threads for cloud fetch + ssl_options: SSL options for cloud fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + # Call parent constructor with common attributes super().__init__( connection=connection, @@ -184,7 +206,7 @@ def __init__( status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 090ec255e..63bc92fdc 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that 
will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -210,6 +212,8 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) + result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -254,7 +258,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7249a59e6..18be51da8 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,6 +40,17 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( @@ -47,18 +58,16 @@ def make_dummy_result_set_from_initial_results(initial_results): status=None, has_been_closed_server_side=True, has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - results_queue=arrow_queue, + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] + + # Replace the results queue with our arrow_queue + rs.results = arrow_queue return rs @staticmethod @@ -85,6 +94,11 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( @@ -92,12 +106,8 @@ def fetch_results( status=None, has_been_closed_server_side=False, has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in 
range(num_cols) - ], - lz4_compressed=Mock(), - results_queue=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index c9cb05305..7165c6259 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -511,10 +511,10 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): self.assertEqual( description, [ - ("column 1", "int", None, None, None, None, None), - ("column 2", "boolean", None, None, None, None, None), - ("column 2", "map", None, None, None, None, None), - ("", "struct", None, None, None, None, None), + ["column 1", "int", None, None, None, None, None], + ["column 2", "boolean", None, None, None, None, None], + ["column 2", "map", None, None, None, None, None], + ["", "struct", None, None, None, None, None], ], ) @@ -549,7 +549,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): self.assertEqual( description, [ - ("column 1", "decimal", None, None, 10, 100, None), + ["column 1", "decimal", None, None, 10, 100, None], ], ) @@ -1161,8 +1161,11 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1178,6 +1181,8 @@ def test_execute_statement_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) @@ -1195,8 +1200,11 @@ def test_execute_statement_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1212,6 +1220,8 @@ def test_get_catalogs_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet self.assertIsInstance(result, ResultSet) @@ -1226,8 +1236,11 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1243,6 +1256,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + 
thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_schemas( Mock(), 100, @@ -1266,8 +1281,11 @@ def test_get_schemas_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1283,6 +1301,8 @@ def test_get_tables_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_tables( Mock(), 100, @@ -1310,8 +1330,11 @@ def test_get_tables_calls_client_and_handle_execute_response( ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1327,6 +1350,8 @@ def test_get_columns_calls_client_and_handle_execute_response( thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + result = thrift_backend.get_columns( Mock(), 100, @@ -2228,6 +2253,9 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", return_value=Mock( @@ -2236,15 +2264,15 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): status=Mock(), description=Mock(), has_more_rows=Mock(), - results_queue=Mock(), has_been_closed_server_side=Mock(), lz4_compressed=Mock(), is_staging_operation=Mock(), arrow_schema_bytes=Mock(), + result_format=Mock(), ), ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, mock_build_queue, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value # Iterate through each possible combination of native types (True, False and unset) @@ -2268,6 +2296,9 @@ def test_execute_command_sets_complex_type_fields_correctly( ssl_options=SSLOptions(), **complex_arg_types, ) + + thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) + thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From 8a014f01df6137685a3acd58f10852d73fba3c2f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 06:10:58 +0000 Subject: [PATCH 047/204] move description to List[Tuple] Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/backend/types.py | 2 +- src/databricks/sql/utils.py | 6 +++--- tests/unit/test_thrift_backend.py | 10 +++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py 
b/src/databricks/sql/backend/thrift_backend.py index 2e3e61ca0..3792d4935 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -727,7 +727,7 @@ def _col_to_description(col): else: precision, scale = None, None - return [col.columnName, cleaned_type, None, None, precision, scale, None] + return (col.columnName, cleaned_type, None, None, precision, scale, None) @staticmethod def _hive_schema_to_description(t_table_schema): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index ba2975d7c..249816eab 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,7 +423,7 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[List[Any]]] = None + description: Optional[List[Tuple]] = None has_more_rows: bool = False has_been_closed_server_side: bool = False lz4_compressed: bool = True diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index edb13ef6d..d7b1b74b4 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
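[Editor's note, illustrative only and not part of the patch: the utils.py and types.py hunks above narrow the `description` type from `List[List[Any]]` to `List[Tuple]`. A minimal sketch of the PEP 249-style row metadata this implies is shown below; the column names and the decimal precision/scale values mirror the unit-test expectations later in this patch, and the `description` variable itself is purely hypothetical.]

    from typing import List, Optional, Tuple

    # Each description entry is a 7-item tuple:
    # (name, type_code, display_size, internal_size, precision, scale, null_ok).
    # Only decimal columns carry precision/scale; the rest are None.
    description: Optional[List[Tuple]] = [
        ("column 1", "int", None, None, None, None, None),
        ("column 1", "decimal", None, None, 10, 100, None),
    ]

[The test hunks that follow assert exactly this tuple shape.]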
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7165c6259..aae11c56c 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -511,10 +511,10 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): self.assertEqual( description, [ - ["column 1", "int", None, None, None, None, None], - ["column 2", "boolean", None, None, None, None, None], - ["column 2", "map", None, None, None, None, None], - ["", "struct", None, None, None, None, None], + ("column 1", "int", None, None, None, None, None), + ("column 2", "boolean", None, None, None, None, None), + ("column 2", "map", None, None, None, None, None), + ("", "struct", None, None, None, None, None), ], ) @@ -549,7 +549,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): self.assertEqual( description, [ - ["column 1", "decimal", None, None, 10, 100, None], + ("column 1", "decimal", None, None, 10, 100, None), ], ) From 39c41ab9abf54e0fc4d1fbc8c02abe02271fb866 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 06:12:10 +0000 Subject: [PATCH 048/204] frmatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ab3fb68f2..dc72382c6 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -184,7 +184,7 @@ def __init__( results_queue = None if t_row_set and execute_response.result_format is not None: from databricks.sql.utils import ResultSetQueueFactory - + # Create the results queue using the provided format results_queue = ResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, From 2cd04dfc331b7ef8335cdca288884a951a4dc269 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 06:13:12 +0000 Subject: [PATCH 049/204] reduce diff (remove explicit tuple conversion) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 3792d4935..f2e95fb66 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -727,7 +727,7 @@ def _col_to_description(col): else: precision, scale = None, None - return (col.columnName, cleaned_type, None, None, precision, scale, None) + return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod def _hive_schema_to_description(t_table_schema): From 067a01967c4fe9b6b5e4bc83792b6457e2666c12 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 08:51:35 +0000 Subject: [PATCH 050/204] remove has_more_rows from ExecuteResponse Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 -- src/databricks/sql/backend/types.py | 1 - src/databricks/sql/result_set.py | 2 +- tests/unit/test_fetches.py | 2 -- tests/unit/test_thrift_backend.py | 14 +++++++------- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f2e95fb66..46f5ef02e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -787,7 +787,6 @@ def _results_message_to_execute_response(self, resp, operation_state): 
command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, is_staging_operation=t_result_set_metadata_resp.isStagingOperation, @@ -843,7 +842,6 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 249816eab..93bd7d525 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -424,7 +424,6 @@ class ExecuteResponse: command_id: CommandId status: CommandState description: Optional[List[Tuple]] = None - has_more_rows: bool = False has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index dc72382c6..fb9b417c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -205,7 +205,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, + has_more_rows=False, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 18be51da8..ba9b50aef 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -57,7 +57,6 @@ def make_dummy_result_set_from_initial_results(initial_results): command_id=None, status=None, has_been_closed_server_side=True, - has_more_rows=False, description=description, lz4_compressed=True, is_staging_operation=False, @@ -105,7 +104,6 @@ def fetch_results( command_id=None, status=None, has_been_closed_server_side=False, - has_more_rows=True, description=description, lz4_compressed=True, is_staging_operation=False, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index aae11c56c..bab9cb3ca 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1009,13 +1009,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_reads_has_more_rows_in_direct_results( + def test_handle_execute_response_creates_execute_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( - [True, False], self.execute_response_types - ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + """Test that _handle_execute_response creates an ExecuteResponse object correctly.""" + for resp_type in self.execute_response_types: + with self.subTest(resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1027,7 +1026,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=True, results=results_mock, ), closeOperation=Mock(), @@ 
-1047,7 +1046,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( execute_resp, Mock() ) - self.assertEqual(has_more_rows, execute_response.has_more_rows) + self.assertIsNotNone(execute_response) + self.assertIsInstance(execute_response, ExecuteResponse) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() From 48c83e095afe26438b2da71a6bdd6be9e03d1d7d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 09:02:02 +0000 Subject: [PATCH 051/204] remove un-necessary has_more_rows aclc Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 46f5ef02e..7cdd583d5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -757,11 +757,7 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( - (not direct_results) - or (not direct_results.resultSet) - or direct_results.resultSet.hasMoreRows - ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) From 281a9e9675f5b573c87053f47c07517e2a4db2ca Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 10:33:27 +0000 Subject: [PATCH 052/204] default has_more_rows to True Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index fb9b417c1..cb6c5e1c3 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -205,7 +205,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=False, + has_more_rows=True, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, From 192901d2f51bf4764276c60bdd75a005e0562de0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 11:40:42 +0000 Subject: [PATCH 053/204] return has_more_rows from ExecResponse conversion during GetRespMetadata Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 28 ++- src/databricks/sql/result_set.py | 4 +- tests/unit/test_thrift_backend.py | 244 +++++++++---------- 3 files changed, 137 insertions(+), 139 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 7cdd583d5..ffbd2885e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -758,6 +758,12 @@ def _results_message_to_execute_response(self, resp, operation_state): direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation + has_more_rows = ( + (not direct_results) + or (not direct_results.resultSet) + or direct_results.resultSet.hasMoreRows + ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,7 +785,7 @@ def _results_message_to_execute_response(self, resp, operation_state): if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return 
ExecuteResponse( + execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, @@ -790,6 +796,8 @@ def _results_message_to_execute_response(self, resp, operation_state): result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, has_more_rows + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -855,6 +863,7 @@ def get_execution_result( t_row_set=resp.results, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -967,7 +976,9 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response( + resp, cursor + ) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -983,6 +994,7 @@ def execute_command( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_catalogs( @@ -1004,7 +1016,7 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1020,6 +1032,7 @@ def get_catalogs( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_schemas( @@ -1045,7 +1058,7 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1061,6 +1074,7 @@ def get_schemas( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_tables( @@ -1090,7 +1104,7 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1106,6 +1120,7 @@ def get_tables( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def get_columns( @@ -1135,7 +1150,7 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, has_more_rows = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1151,6 +1166,7 @@ def get_columns( t_row_set=t_row_set, max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, + has_more_rows=has_more_rows, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cb6c5e1c3..9857d9e0f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -160,6 +160,7 @@ def __init__( t_row_set=None, max_download_threads: int = 10, ssl_options=None, + has_more_rows: bool = True, ): 
""" Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. @@ -174,6 +175,7 @@ def __init__( t_row_set: The TRowSet containing result data (if available) max_download_threads: Maximum number of download threads for cloud fetch ssl_options: SSL options for cloud fetch + has_more_rows: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes self._arrow_schema_bytes = execute_response.arrow_schema_bytes @@ -205,7 +207,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=True, + has_more_rows=has_more_rows, results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index bab9cb3ca..4f5e14cab 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -82,14 +82,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -100,8 +93,22 @@ def _make_type_desc(self, type): ] ) - def _make_fake_thrift_backend(self): - thrift_backend = ThriftDatabricksClient( + def _create_mock_execute_response(self): + """Create a properly mocked ExecuteResponse object with all required attributes.""" + mock_execute_response = Mock() + mock_execute_response.command_id = Mock() + mock_execute_response.status = Mock() + mock_execute_response.description = Mock() + mock_execute_response.has_been_closed_server_side = Mock() + mock_execute_response.lz4_compressed = Mock() + mock_execute_response.is_staging_operation = Mock() + mock_execute_response.arrow_schema_bytes = Mock() + mock_execute_response.result_format = Mock() + return mock_execute_response + + def _create_fake_thrift_client(self): + """Create a fake ThriftDatabricksClient without mocking any methods.""" + return ThriftDatabricksClient( "foobar", 443, "path", @@ -109,10 +116,20 @@ def _make_fake_thrift_backend(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) + + def _make_fake_thrift_backend(self): + """Create a fake ThriftDatabricksClient with mocked methods.""" + thrift_backend = self._create_fake_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() thrift_backend._create_arrow_table.return_value = (MagicMock(), Mock()) + # Mock _results_message_to_execute_response to return a tuple + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._results_message_to_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) + ) return thrift_backend def test_hive_schema_to_arrow_schema_preserves_column_names(self): @@ -558,14 +575,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = 
self._create_fake_thrift_client() for code in error_codes: mock_error_response = Mock() @@ -602,14 +612,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -657,7 +660,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -832,14 +835,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -891,7 +887,7 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -921,21 +917,22 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = Mock(spec=ExecuteResponse) + mock_has_more_rows = True + thrift_backend._results_message_to_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._results_message_to_execute_response = Mock() - thrift_backend._handle_execute_response(execute_resp, Mock()) + result = thrift_backend._handle_execute_response(execute_resp, Mock()) thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, ) + # Verify the result is a tuple with the expected values + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], mock_execute_response) + self.assertEqual(result[1], mock_has_more_rows) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): @@ -965,9 +962,12 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( - t_execute_resp, Mock() + + thrift_backend = self._create_fake_thrift_client() + + # Call the real _results_message_to_execute_response method + execute_response, _ = thrift_backend._results_message_to_execute_response( + t_execute_resp, ttypes.TOperationState.FINISHED_STATE ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @@ -997,8 +997,14 @@ def 
test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + + thrift_backend = self._create_fake_thrift_client() + thrift_backend._hive_schema_to_arrow_schema = Mock() + + # Call the real _results_message_to_execute_response method + thrift_backend._results_message_to_execute_response( + t_execute_resp, ttypes.TOperationState.FINISHED_STATE + ) self.assertEqual( hive_schema_mock, @@ -1040,14 +1046,16 @@ def test_handle_execute_response_creates_execute_response( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_fake_thrift_client() - execute_response = thrift_backend._handle_execute_response( + execute_response_tuple = thrift_backend._handle_execute_response( execute_resp, Mock() ) - self.assertIsNotNone(execute_response) - self.assertIsInstance(execute_response, ExecuteResponse) + self.assertIsNotNone(execute_response_tuple) + self.assertIsInstance(execute_response_tuple, tuple) + self.assertIsInstance(execute_response_tuple[0], ExecuteResponse) + self.assertIsInstance(execute_response_tuple[1], bool) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1178,7 +1186,11 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) + ) cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1209,15 +1221,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1245,15 +1255,12 @@ def test_get_schemas_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1290,15 +1297,12 @@ def 
test_get_tables_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1339,15 +1343,12 @@ def test_get_columns_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), + thrift_backend = self._make_fake_thrift_backend() + mock_execute_response = self._create_mock_execute_response() + mock_has_more_rows = True + thrift_backend._handle_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) ) - thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) @@ -1397,14 +1398,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) self.assertEqual( @@ -1415,14 +1409,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( @@ -1458,7 +1445,8 @@ def test_non_arrow_non_column_based_set_triggers_exception( tcli_service_instance.ExecuteStatement.return_value = execute_statement_resp tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - thrift_backend = self._make_fake_thrift_backend() + + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) @@ -1468,14 +1456,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - 
auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1488,14 +1469,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftDatabricksClient( - "foobar", - 443, - "path", - [], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) + thrift_backend = self._create_fake_thrift_client() convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1695,7 +1669,11 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + mock_execute_response = Mock(spec=ExecuteResponse) + mock_has_more_rows = True + thrift_backend._results_message_to_execute_response = Mock( + return_value=(mock_execute_response, mock_has_more_rows) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2258,17 +2236,19 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): ) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_more_rows=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - result_format=Mock(), + return_value=( + Mock( + spec=ExecuteResponse, + command_id=Mock(), + status=Mock(), + description=Mock(), + has_been_closed_server_side=Mock(), + lz4_compressed=Mock(), + is_staging_operation=Mock(), + arrow_schema_bytes=Mock(), + result_format=Mock(), + ), + True, # has_more_rows ), ) def test_execute_command_sets_complex_type_fields_correctly( From 55f5c45a9fe18ac76839a4b8ff4955e58af18fe6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 12:38:24 +0000 Subject: [PATCH 054/204] remove unnecessary replacement Signed-off-by: varun-edachali-dbx --- tests/unit/test_fetches.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index ba9b50aef..a649941e1 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -64,9 +64,6 @@ def make_dummy_result_set_from_initial_results(initial_results): thrift_client=mock_thrift_backend, t_row_set=None, ) - - # Replace the results queue with our arrow_queue - rs.results = arrow_queue return rs @staticmethod From edc36b5540d178f6e52bc022eeb265122d6c7d81 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 12:41:12 +0000 Subject: [PATCH 055/204] better mocked backend naming Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 4f5e14cab..8582fd7f9 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -117,7 +117,7 @@ def _create_fake_thrift_client(self): ssl_options=SSLOptions(), ) - def _make_fake_thrift_backend(self): + def _create_mocked_thrift_client(self): """Create 
a fake ThriftDatabricksClient with mocked methods.""" thrift_backend = self._create_fake_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() @@ -184,7 +184,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): ) with self.assertRaises(OperationalError) as cm: - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() thrift_backend.open_session({}, None, None) self.assertIn( @@ -207,7 +207,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): sessionHandle=self.session_handle, ) - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() thrift_backend.open_session({}, None, None) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -917,7 +917,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() mock_execute_response = Mock(spec=ExecuteResponse) mock_has_more_rows = True thrift_backend._results_message_to_execute_response = Mock( @@ -1100,7 +1100,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( @@ -1221,7 +1221,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() mock_execute_response = self._create_mock_execute_response() mock_has_more_rows = True @@ -1255,7 +1255,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() mock_execute_response = self._create_mock_execute_response() mock_has_more_rows = True thrift_backend._handle_execute_response = Mock( @@ -1297,7 +1297,7 @@ def test_get_tables_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() mock_execute_response = self._create_mock_execute_response() mock_has_more_rows = True thrift_backend._handle_execute_response = Mock( @@ -1343,7 +1343,7 @@ def test_get_columns_calls_client_and_handle_execute_response( tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() mock_execute_response = self._create_mock_execute_response() mock_has_more_rows = True thrift_backend._handle_execute_response = Mock( @@ -1655,7 +1655,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): def test_cancel_command_uses_active_op_handle(self, tcli_service_class): 
tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() # Create a proper CommandId from the existing operation_handle command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.cancel_command(command_id) @@ -1666,7 +1666,7 @@ def test_cancel_command_uses_active_op_handle(self, tcli_service_class): ) def test_handle_execute_response_sets_active_op_handle(self): - thrift_backend = self._make_fake_thrift_backend() + thrift_backend = self._create_mocked_thrift_client() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() mock_execute_response = Mock(spec=ExecuteResponse) From 81280e701d52609a5ad59deab63d2e24012d2002 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 12:47:06 +0000 Subject: [PATCH 056/204] remove has_more_rows test in ExecuteResponse Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 47 ------------------------------- 1 file changed, 47 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 8582fd7f9..2054cb65a 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -990,7 +990,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response op_state = ttypes.TGetOperationStatusResp( status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE, @@ -1011,52 +1010,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend._hive_schema_to_arrow_schema.call_args[0][0], ) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) - @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_creates_execute_response( - self, tcli_service_class, build_queue - ): - """Test that _handle_execute_response creates an ExecuteResponse object correctly.""" - for resp_type in self.execute_response_types: - with self.subTest(resp_type=resp_type): - tcli_service_instance = tcli_service_class.return_value - results_mock = Mock() - results_mock.startRowOffset = 0 - direct_results_message = ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ), - resultSetMetadata=self.metadata_resp, - resultSet=ttypes.TFetchResultsResp( - status=self.okay_status, - hasMoreRows=True, - results=results_mock, - ), - closeOperation=Mock(), - ) - execute_resp = resp_type( - status=self.okay_status, - directResults=direct_results_message, - operationHandle=self.operation_handle, - ) - - tcli_service_instance.GetResultSetMetadata.return_value = ( - self.metadata_resp - ) - thrift_backend = self._create_fake_thrift_client() - - execute_response_tuple = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - - self.assertIsNotNone(execute_response_tuple) - self.assertIsInstance(execute_response_tuple, tuple) - self.assertIsInstance(execute_response_tuple[0], ExecuteResponse) - self.assertIsInstance(execute_response_tuple[1], bool) - @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) From c1d3be2fadc4d1aab3f63136ddcff6e2a4a1931a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:11:36 +0000 
Subject: [PATCH 057/204] introduce replacement of original has_more_rows read test Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 78 +++++++++++++------------------ 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 2054cb65a..3bdf1434d 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -82,7 +82,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -106,7 +106,7 @@ def _create_mock_execute_response(self): mock_execute_response.result_format = Mock() return mock_execute_response - def _create_fake_thrift_client(self): + def _create_thrift_client(self): """Create a fake ThriftDatabricksClient without mocking any methods.""" return ThriftDatabricksClient( "foobar", @@ -119,7 +119,7 @@ def _create_fake_thrift_client(self): def _create_mocked_thrift_client(self): """Create a fake ThriftDatabricksClient with mocked methods.""" - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -575,7 +575,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() for code in error_codes: mock_error_response = Mock() @@ -612,7 +612,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -835,7 +835,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -963,7 +963,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_get_result_set_metadata_resp ) - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() # Call the real _results_message_to_execute_response method execute_response, _ = thrift_backend._results_message_to_execute_response( @@ -997,7 +997,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() # Call the real _results_message_to_execute_response method @@ -1014,7 +1014,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): "databricks.sql.utils.ResultSetQueueFactory.build_queue", 
return_value=Mock() ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_reads_has_more_rows_in_result_response( + def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): for has_more_rows, resp_type in itertools.product( @@ -1022,48 +1022,34 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( ): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value - results_mock = MagicMock() + results_mock = Mock() results_mock.startRowOffset = 0 - - execute_resp = resp_type( - status=self.okay_status, - directResults=None, - operationHandle=self.operation_handle, - ) - - fetch_results_resp = ttypes.TFetchResultsResp( - status=self.okay_status, - hasMoreRows=has_more_rows, - results=results_mock, - resultSetMetadata=ttypes.TGetResultSetMetadataResp( - resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + direct_results_message = ttypes.TSparkDirectResults( + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ), + resultSetMetadata=self.metadata_resp, + resultSet=ttypes.TFetchResultsResp( + status=self.okay_status, + hasMoreRows=has_more_rows, + results=results_mock, + ), + closeOperation=Mock(), ) - - operation_status_resp = ttypes.TGetOperationStatusResp( + execute_resp = resp_type( status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - errorMessage="some information about the error", + directResults=direct_results_message, + operationHandle=self.operation_handle, ) - tcli_service_instance.FetchResults.return_value = fetch_results_resp - tcli_service_instance.GetOperationStatus.return_value = ( - operation_status_resp - ) tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() - thrift_backend._handle_execute_response(execute_resp, Mock()) - _, has_more_rows_resp = thrift_backend.fetch_results( - command_id=Mock(), - max_rows=1, - max_bytes=1, - expected_row_start_offset=0, - lz4_compressed=False, - arrow_schema_bytes=Mock(), - description=Mock(), + _, has_more_rows_resp = thrift_backend._handle_execute_response( + execute_resp, Mock() ) self.assertEqual(has_more_rows, has_more_rows_resp) @@ -1351,7 +1337,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) self.assertEqual( @@ -1362,7 +1348,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( @@ -1399,7 
+1385,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) @@ -1409,7 +1395,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1422,7 +1408,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = self._create_fake_thrift_client() + thrift_backend = self._create_thrift_client() convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) From 5ee41367701696a2cd4f791a2633b374a36ced0c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:14:18 +0000 Subject: [PATCH 058/204] call correct method in test_use_arrow_schema Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 3bdf1434d..fc56feea6 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -966,8 +966,8 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): thrift_backend = self._create_thrift_client() # Call the real _results_message_to_execute_response method - execute_response, _ = thrift_backend._results_message_to_execute_response( - t_execute_resp, ttypes.TOperationState.FINISHED_STATE + execute_response, _ = thrift_backend._handle_execute_response( + t_execute_resp, Mock() ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) From b881ab0823f31d709c5d76aa00d9d051506eb835 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:15:41 +0000 Subject: [PATCH 059/204] call correct method in test_fall_back_to_hive_schema Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index fc56feea6..cbde1a29b 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -965,7 +965,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): thrift_backend = self._create_thrift_client() - # Call the real _results_message_to_execute_response method execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -1000,10 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() - # Call the real _results_message_to_execute_response method - thrift_backend._results_message_to_execute_response( - t_execute_resp, ttypes.TOperationState.FINISHED_STATE - ) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( 
hive_schema_mock, From 53bf715a28e59043e7f692ee67b3ef5be36740a0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:17:54 +0000 Subject: [PATCH 060/204] re-introduce result response read test Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 58 +++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index cbde1a29b..b7922d729 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1050,6 +1050,64 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, has_more_rows_resp) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + def test_handle_execute_response_reads_has_more_rows_in_result_response( + self, tcli_service_class, build_queue + ): + for has_more_rows, resp_type in itertools.product( + [True, False], self.execute_response_types + ): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + tcli_service_instance = tcli_service_class.return_value + results_mock = MagicMock() + results_mock.startRowOffset = 0 + + execute_resp = resp_type( + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, + ) + + fetch_results_resp = ttypes.TFetchResultsResp( + status=self.okay_status, + hasMoreRows=has_more_rows, + results=results_mock, + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ), + ) + + operation_status_resp = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + errorMessage="some information about the error", + ) + + tcli_service_instance.FetchResults.return_value = fetch_results_resp + tcli_service_instance.GetOperationStatus.return_value = ( + operation_status_resp + ) + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) + thrift_backend = self._create_thrift_client() + + thrift_backend._handle_execute_response(execute_resp, Mock()) + _, has_more_rows_resp = thrift_backend.fetch_results( + command_id=Mock(), + max_rows=1, + max_bytes=1, + expected_row_start_offset=0, + lz4_compressed=False, + arrow_schema_bytes=Mock(), + description=Mock(), + ) + + self.assertEqual(has_more_rows, has_more_rows_resp) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue From 45a32be5915927bce570710e0375488580041bf8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:20:54 +0000 Subject: [PATCH 061/204] simplify test Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b7922d729..c54fabf40 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -184,7 +184,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): ) with self.assertRaises(OperationalError) as cm: - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend.open_session({}, None, None) 
self.assertIn( @@ -207,7 +207,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): sessionHandle=self.session_handle, ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._create_thrift_client() thrift_backend.open_session({}, None, None) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -918,21 +918,12 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = Mock(spec=ExecuteResponse) - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) - result = thrift_backend._handle_execute_response(execute_resp, Mock()) + thrift_backend._handle_execute_response(execute_resp, Mock()) thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, ) - # Verify the result is a tuple with the expected values - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], mock_execute_response) - self.assertEqual(result[1], mock_has_more_rows) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): From e3fe29979743c14099e9d7f88daf2b3f750121a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 13:35:16 +0000 Subject: [PATCH 062/204] remove excess fetch_results mocks Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 2 -- tests/unit/test_thrift_backend.py | 12 ------------ 2 files changed, 14 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 63bc92fdc..1f0c34025 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -213,7 +213,6 @@ def test_closing_result_set_hard_closes_commands(self): type(mock_connection).session = PropertyMock(return_value=mock_session) mock_thrift_backend.fetch_results.return_value = (Mock(), False) - result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -479,7 +478,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index c54fabf40..7a59c6256 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1177,8 +1177,6 @@ def test_execute_statement_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) @@ -1214,8 +1212,6 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet self.assertIsInstance(result, ResultSet) @@ -1247,8 +1243,6 @@ def test_get_schemas_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_schemas( Mock(), 100, @@ -1289,8 +1283,6 @@ def 
test_get_tables_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_tables( Mock(), 100, @@ -1335,8 +1327,6 @@ def test_get_columns_calls_client_and_handle_execute_response( ) cursor_mock = Mock() - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - result = thrift_backend.get_columns( Mock(), 100, @@ -2261,8 +2251,6 @@ def test_execute_command_sets_complex_type_fields_correctly( **complex_arg_types, ) - thrift_backend.fetch_results = Mock(return_value=(Mock(), False)) - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From e8038d3ac07ebc368f30f6c9102e578691891c75 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 15:25:19 +0000 Subject: [PATCH 063/204] more minimal changes to thrift_backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 347 ++++++++++++++++-------------- 1 file changed, 183 insertions(+), 164 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7a59c6256..5d9da0e13 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -19,13 +19,7 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ResultSet, ThriftResultSet -from databricks.sql.backend.types import ( - CommandId, - CommandState, - SessionId, - BackendType, - ExecuteResponse, -) +from databricks.sql.backend.types import CommandId, CommandState, SessionId, BackendType def retry_policy_factory(): @@ -82,7 +76,14 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -93,22 +94,8 @@ def _make_type_desc(self, type): ] ) - def _create_mock_execute_response(self): - """Create a properly mocked ExecuteResponse object with all required attributes.""" - mock_execute_response = Mock() - mock_execute_response.command_id = Mock() - mock_execute_response.status = Mock() - mock_execute_response.description = Mock() - mock_execute_response.has_been_closed_server_side = Mock() - mock_execute_response.lz4_compressed = Mock() - mock_execute_response.is_staging_operation = Mock() - mock_execute_response.arrow_schema_bytes = Mock() - mock_execute_response.result_format = Mock() - return mock_execute_response - - def _create_thrift_client(self): - """Create a fake ThriftDatabricksClient without mocking any methods.""" - return ThriftDatabricksClient( + def _make_fake_thrift_backend(self): + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -116,20 +103,10 @@ def _create_thrift_client(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - - def _create_mocked_thrift_client(self): - """Create a fake ThriftDatabricksClient with mocked methods.""" - thrift_backend = self._create_thrift_client() thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() 
thrift_backend._create_arrow_table.return_value = (MagicMock(), Mock()) - # Mock _results_message_to_execute_response to return a tuple - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) return thrift_backend def test_hive_schema_to_arrow_schema_preserves_column_names(self): @@ -184,7 +161,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): ) with self.assertRaises(OperationalError) as cm: - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) self.assertIn( @@ -207,7 +184,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): sessionHandle=self.session_handle, ) - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -575,7 +552,14 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) for code in error_codes: mock_error_response = Mock() @@ -612,7 +596,14 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -628,18 +619,14 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=op_status, + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,15 +822,23 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ 
-887,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + arrow_schema_bytes, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -917,9 +912,18 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) + thrift_backend._results_message_to_execute_response.assert_called_with( execute_resp, ttypes.TOperationState.FINISHED_STATE, @@ -944,18 +948,16 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - - thrift_backend = self._create_thrift_client() - + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) + thrift_backend = self._make_fake_thrift_backend() execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -980,17 +982,17 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - - thrift_backend = self._create_thrift_client() - thrift_backend._hive_schema_to_arrow_schema = Mock() - - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) + thrift_backend = self._make_fake_thrift_backend() + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual( hive_schema_mock, @@ -1033,13 +1035,14 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() - _, has_more_rows_resp = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_result) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1084,7 +1087,7 @@ def 
test_handle_execute_response_reads_has_more_rows_in_result_response( tcli_service_instance.GetResultSetMetadata.return_value = ( self.metadata_resp ) - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( @@ -1152,12 +1155,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_execute_statement_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1170,18 +1171,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1193,28 +1191,29 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_catalogs_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = self._create_mocked_thrift_client() - - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1225,22 +1224,24 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - 
"databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_schemas_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1252,7 +1253,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1265,22 +1266,24 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_tables_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1291,10 +1294,10 @@ def test_get_tables_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - table_types=["type1", "type2"], + table_types=["VIEW", "TABLE"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1303,28 +1306,30 @@ def test_get_tables_calls_client_and_handle_execute_response( self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") self.assertEqual(req.tableName, "table_pattern") - self.assertEqual(req.tableTypes, ["type1", "type2"]) + self.assertEqual(req.tableTypes, ["VIEW", "TABLE"]) # Check response handling thrift_backend._handle_execute_response.assert_called_with( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - 
@patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) def test_get_columns_calls_client_and_handle_execute_response( - self, mock_build_queue, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = self._create_mocked_thrift_client() - mock_execute_response = self._create_mock_execute_response() - mock_has_more_rows = True - thrift_backend._handle_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1338,7 +1343,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1372,7 +1377,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) self.assertEqual( @@ -1383,7 +1395,14 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( @@ -1419,8 +1438,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( tcli_service_instance.ExecuteStatement.return_value = execute_statement_resp tcli_service_instance.GetResultSetMetadata.return_value = metadata_resp tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - - thrift_backend = self._create_thrift_client() + thrift_backend = self._make_fake_thrift_backend() with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) @@ -1430,7 +1448,14 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ 
-1443,7 +1468,14 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = self._create_thrift_client() + thrift_backend = ThriftDatabricksClient( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1629,7 +1661,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._make_fake_thrift_backend() # Create a proper CommandId from the existing operation_handle command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.cancel_command(command_id) @@ -1640,14 +1672,10 @@ def test_cancel_command_uses_active_op_handle(self, tcli_service_class): ) def test_handle_execute_response_sets_active_op_handle(self): - thrift_backend = self._create_mocked_thrift_client() + thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - mock_execute_response = Mock(spec=ExecuteResponse) - mock_has_more_rows = True - thrift_backend._results_message_to_execute_response = Mock( - return_value=(mock_execute_response, mock_has_more_rows) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,31 +2232,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() - ) - @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=( - Mock( - spec=ExecuteResponse, - command_id=Mock(), - status=Mock(), - description=Mock(), - has_been_closed_server_side=Mock(), - lz4_compressed=Mock(), - is_staging_operation=Mock(), - arrow_schema_bytes=Mock(), - result_format=Mock(), - ), - True, # has_more_rows - ), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, mock_build_queue, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] @@ -2250,7 +2270,6 @@ def test_execute_command_sets_complex_type_fields_correctly( ssl_options=SSLOptions(), **complex_arg_types, ) - thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ 0 From 
2f6ec19b29dc0bffced7e96ec2ef596880aa7193 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 15:33:48 +0000 Subject: [PATCH 064/204] move back to old table types Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 5d9da0e13..61b96e523 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1294,7 +1294,7 @@ def test_get_tables_calls_client_and_handle_execute_response( catalog_name="catalog_pattern", schema_name="schema_pattern", table_name="table_pattern", - table_types=["VIEW", "TABLE"], + table_types=["type1", "type2"], ) # Verify the result is a ResultSet self.assertEqual(result, mock_result_set.return_value) @@ -1306,7 +1306,7 @@ def test_get_tables_calls_client_and_handle_execute_response( self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") self.assertEqual(req.tableName, "table_pattern") - self.assertEqual(req.tableTypes, ["VIEW", "TABLE"]) + self.assertEqual(req.tableTypes, ["type1", "type2"]) # Check response handling thrift_backend._handle_execute_response.assert_called_with( response, cursor_mock From 73bc28267f83656b7d7f82cab77721cf93ef013f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 10 Jun 2025 15:35:14 +0000 Subject: [PATCH 065/204] remove outdated arrow_schema_bytes return Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 61b96e523..a05e8cb87 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -884,7 +884,7 @@ def test_handle_execute_response_can_handle_without_direct_results( ) ( execute_response, - arrow_schema_bytes, + _, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, @@ -990,9 +990,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, From e385d5b8b6f9be36183e763286f3406ca6c5c144 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:49:37 +0000 Subject: [PATCH 066/204] backend from cloudfetch-sea Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 375 +++++++++++------- .../sql/backend/sea/models/responses.py | 12 +- .../sql/backend/sea/utils/http_client.py | 2 +- 3 files changed, 233 insertions(+), 156 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a54337f0c..c1f21448b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,8 +1,8 @@ import logging -import re import uuid import time -from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING +import re +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, @@ -11,13 +11,26 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor from 
databricks.sql.result_set import ResultSet + from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, +) from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions +from databricks.sql.utils import SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ( + ResultData, + ExternalLink, + ResultManifest, +) from databricks.sql.backend.sea.models import ( ExecuteStatementRequest, @@ -66,6 +79,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -75,6 +91,8 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" + CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -107,6 +125,7 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) + self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -263,6 +282,19 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + @staticmethod + def is_session_configuration_parameter_supported(name: str) -> bool: + """ + Check if a session configuration parameter is supported. + + Args: + name: The name of the session configuration parameter + + Returns: + True if the parameter is supported, False otherwise + """ + return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP + @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -273,8 +305,182 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - # == Not Implemented Operations == - # These methods will be implemented in future iterations + def _get_schema_bytes(self, sea_response) -> Optional[bytes]: + """ + Extract schema bytes from the SEA response. + + For ARROW format, we need to get the schema bytes from the first chunk. + If the first chunk is not available, we need to get it from the server. 
+ + Args: + sea_response: The response from the SEA API + + Returns: + bytes: The schema bytes or None if not available + """ + import requests + import lz4.frame + + # Check if we have the first chunk in the response + result_data = sea_response.get("result", {}) + external_links = result_data.get("external_links", []) + + if not external_links: + return None + + # Find the first chunk (chunk_index = 0) + first_chunk = None + for link in external_links: + if link.get("chunk_index") == 0: + first_chunk = link + break + + if not first_chunk: + # Try to fetch the first chunk from the server + statement_id = sea_response.get("statement_id") + if not statement_id: + return None + + chunks_response = self.get_chunk_links(statement_id, 0) + if not chunks_response.external_links: + return None + + first_chunk = chunks_response.external_links[0].__dict__ + + # Download the first chunk to get the schema bytes + external_link = first_chunk.get("external_link") + http_headers = first_chunk.get("http_headers", {}) + + if not external_link: + return None + + # Use requests to download the first chunk + http_response = requests.get( + external_link, + headers=http_headers, + verify=self.ssl_options.tls_verify, + ) + + if http_response.status_code != 200: + raise Error(f"Failed to download schema bytes: {http_response.text}") + + # Extract schema bytes from the Arrow file + # The schema is at the beginning of the file + data = http_response.content + if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": + data = lz4.frame.decompress(data) + + # Return the schema bytes + return data + + def _results_message_to_execute_response(self, sea_response, command_id): + """ + Convert a SEA response to an ExecuteResponse and extract result data. + + Args: + sea_response: The response from the SEA API + command_id: The command ID + + Returns: + tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, + result data object, and manifest object + """ + # Extract status + status_data = sea_response.get("status", {}) + state = CommandState.from_sea_state(status_data.get("state", "")) + + # Extract description from manifest + description = None + manifest_data = sea_response.get("manifest", {}) + schema_data = manifest_data.get("schema", {}) + columns_data = schema_data.get("columns", []) + + if columns_data: + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + description = columns if columns else None + + # Extract schema bytes for Arrow format + schema_bytes = None + format = manifest_data.get("format") + if format == "ARROW_STREAM": + # For ARROW format, we need to get the schema bytes + schema_bytes = self._get_schema_bytes(sea_response) + + # Check for compression + lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" + + # Initialize result_data_obj and manifest_obj + result_data_obj = None + manifest_obj = None + + result_data = sea_response.get("result", {}) + if result_data: + # Convert external links + external_links = None + if "external_links" in result_data: + external_links = [] + for 
link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers", {}), + ) + ) + + # Create the result data object + result_data_obj = ResultData( + data=result_data.get("data_array"), external_links=external_links + ) + + # Create the manifest object + manifest_obj = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + execute_response = ExecuteResponse( + command_id=command_id, + status=state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=False, + arrow_schema_bytes=schema_bytes, + result_format=manifest_data.get("format"), + ) + + return execute_response, result_data_obj, manifest_obj def execute_command( self, @@ -336,7 +542,7 @@ def execute_command( format=format, wait_timeout="0s" if async_op else "10s", on_wait_timeout="CONTINUE", - row_limit=max_rows if max_rows > 0 else None, + row_limit=max_rows, parameters=sea_parameters if sea_parameters else None, result_compression=result_compression, ) @@ -494,157 +700,20 @@ def get_execution_result( # Create and return a SeaResultSet from databricks.sql.result_set import SeaResultSet + # Convert the response to an ExecuteResponse and extract result data + ( + execute_response, + result_data, + manifest, + ) = self._results_message_to_execute_response(response_data, command_id) + return SeaResultSet( connection=cursor.connection, - sea_response=response_data, + execute_response=execute_response, sea_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, + result_data=result_data, + manifest=manifest, ) - # == Metadata Operations == - - def get_catalogs( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - result = self.execute_command( - operation="SHOW CATALOGS", - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result - - def get_schemas( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") - - operation = f"SHOW SCHEMAS IN `{catalog_name}`" - - if schema_name: - operation += f" LIKE '{schema_name}'" - - result = self.execute_command( - 
operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result - - def get_tables( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - table_name: Optional[str] = None, - table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" - if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" - ) - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" LIKE '{table_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - - # Apply client-side filtering by table_types if specified - from databricks.sql.backend.filters import ResultSetFilter - - result = ResultSetFilter.filter_tables_by_type(result, table_types) - - return result - - def get_columns( - self, - session_id: SessionId, - max_rows: int, - max_bytes: int, - cursor: "Cursor", - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - table_name: Optional[str] = None, - column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_columns") - - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" - - if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" - - if table_name: - operation += f" TABLE LIKE '{table_name}'" - - if column_name: - operation += f" LIKE '{column_name}'" - - result = self.execute_command( - operation=operation, - session_id=session_id, - max_rows=max_rows, - max_bytes=max_bytes, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - assert result is not None, "execute_command returned None in synchronous mode" - return result diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 6b5067506..d684a9c67 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -39,8 +39,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": error_code=error_data.get("error_code"), ) + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( - state=CommandState.from_sea_state(status_data.get("state", "")), + state=state, error=error, sql_state=status_data.get("sql_state"), ) @@ -119,8 +123,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": 
error_code=error_data.get("error_code"), ) + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( - state=CommandState.from_sea_state(status_data.get("state", "")), + state=state, error=error, sql_state=status_data.get("sql_state"), ) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider From 484064ef8cd24e2f6c5cf9ec268d2cfb5597ea4d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:51:22 +0000 Subject: [PATCH 067/204] remove filtering, metadata ops Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 154 ----------- src/databricks/sql/backend/sea/backend.py | 1 - tests/unit/test_result_set_filter.py | 246 ------------------ tests/unit/test_sea_backend.py | 302 ---------------------- 4 files changed, 703 deletions(-) delete mode 100644 src/databricks/sql/backend/filters.py delete mode 100644 tests/unit/test_result_set_filter.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py deleted file mode 100644 index 9fa0a5535..000000000 --- a/src/databricks/sql/backend/filters.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -Client-side filtering utilities for Databricks SQL connector. - -This module provides filtering capabilities for result sets returned by different backends. -""" - -import logging -from typing import ( - List, - Optional, - Any, - Dict, - Callable, - TypeVar, - Generic, - cast, - TYPE_CHECKING, -) - -from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory -from databricks.sql.backend.types import ExecuteResponse, CommandId -from databricks.sql.backend.sea.models.base import ResultData - -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet, SeaResultSet - -logger = logging.getLogger(__name__) - - -class ResultSetFilter: - """ - A general-purpose filter for result sets that can be applied to any backend. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. - """ - - @staticmethod - def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": - """ - Filter a SEA result set using the provided filter function. 
- - Args: - result_set: The SEA result set to filter - filter_func: Function that takes a row and returns True if the row should be included - - Returns: - A filtered SEA result set - """ - # Get all remaining rows from the current position (JDBC-aligned behavior) - # Note: This will only filter rows that haven't been read yet - all_rows = result_set.results.remaining_rows() - - # Filter rows - filtered_rows = [row for row in all_rows if filter_func(row)] - - execute_response = ExecuteResponse( - command_id=result_set.command_id, - status=result_set.status, - description=result_set.description, - has_more_rows=result_set.has_more_rows, - results_queue=JsonQueue(filtered_rows), - has_been_closed_server_side=result_set.has_been_closed_server_side, - lz4_compressed=False, - is_staging_operation=False, - ) - - from databricks.sql.result_set import SeaResultSet - - return SeaResultSet( - connection=result_set.connection, - execute_response=execute_response, - sea_client=result_set.backend, - buffer_size_bytes=result_set.buffer_size_bytes, - arraysize=result_set.arraysize, - ) - - @staticmethod - def filter_by_column_values( - result_set: "ResultSet", - column_index: int, - allowed_values: List[str], - case_sensitive: bool = False, - ) -> "ResultSet": - """ - Filter a result set by values in a specific column. - - Args: - result_set: The result set to filter - column_index: The index of the column to filter on - allowed_values: List of allowed values for the column - case_sensitive: Whether to perform case-sensitive comparison - - Returns: - A filtered result set - """ - # Convert to uppercase for case-insensitive comparison if needed - if not case_sensitive: - allowed_values = [v.upper() for v in allowed_values] - - # Determine the type of result set and apply appropriate filtering - from databricks.sql.result_set import SeaResultSet - - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" - ) - return result_set - - @staticmethod - def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": - """ - Filter a result set of tables by the specified table types. - - This is a client-side filter that processes the result set after it has been - retrieved from the server. It filters out tables whose type does not match - any of the types in the table_types list. 
- - Args: - result_set: The original result set containing tables - table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) - - Returns: - A filtered result set containing only tables of the specified types - """ - # Default table types if none specified - DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] - valid_types = ( - table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES - ) - - # Table type is the 6th column (index 5) - return ResultSetFilter.filter_by_column_values( - result_set, 5, valid_types, case_sensitive=False - ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index c1f21448b..80066ae82 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -716,4 +716,3 @@ def get_execution_result( result_data=result_data, manifest=manifest, ) - diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py deleted file mode 100644 index e8eb2a757..000000000 --- a/tests/unit/test_result_set_filter.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Tests for the ResultSetFilter class. - -This module contains tests for the ResultSetFilter class, which provides -filtering capabilities for result sets returned by different backends. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.backend.filters import ResultSetFilter -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestResultSetFilter: - """Test suite for the ResultSetFilter class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def sea_response_with_tables(self): - """Create a sample SEA response with table data based on the server schema.""" - return { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 7, - "columns": [ - {"name": "namespace", "type_text": "STRING", "position": 0}, - {"name": "tableName", "type_text": "STRING", "position": 1}, - {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, - {"name": "information", "type_text": "STRING", "position": 3}, - {"name": "catalogName", "type_text": "STRING", "position": 4}, - {"name": "tableType", "type_text": "STRING", "position": 5}, - {"name": "remarks", "type_text": "STRING", "position": 6}, - ], - }, - "total_row_count": 4, - }, - "result": { - "data_array": [ - ["schema1", "table1", False, None, "catalog1", "TABLE", None], - ["schema1", "table2", False, None, "catalog1", "VIEW", None], - [ - "schema1", - "table3", - False, - None, - "catalog1", - "SYSTEM TABLE", - None, - ], - ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], - ] - }, - } - - @pytest.fixture - def sea_result_set_with_tables( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Create a SeaResultSet with table data.""" - # Create a deep copy of the response to avoid test interference - import copy - - sea_response_copy = copy.deepcopy(sea_response_with_tables) - - return SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy, - buffer_size_bytes=1000, - arraysize=100, - ) - - def 
test_filter_tables_by_type_default(self, sea_result_set_with_tables): - """Test filtering tables by type with default types.""" - # Default types are TABLE, VIEW, SYSTEM TABLE - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables - ) - - # Verify that only the default types are included - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): - """Test filtering tables by type with custom types.""" - # Filter for only TABLE and EXTERNAL - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] - ) - - # Verify that only the specified types are included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "EXTERNAL" in table_types - assert "VIEW" not in table_types - assert "SYSTEM TABLE" not in table_types - - def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): - """Test that table type filtering is case-insensitive.""" - # Filter for lowercase "table" and "view" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=["table", "view"] - ) - - # Verify that the matching types are included despite case differences - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" not in table_types - assert "EXTERNAL" not in table_types - - def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): - """Test filtering tables with an empty type list (should use defaults).""" - filtered_result = ResultSetFilter.filter_tables_by_type( - sea_result_set_with_tables, table_types=[] - ) - - # Verify that default types are used - assert len(filtered_result._response["result"]["data_array"]) == 3 - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - assert "SYSTEM TABLE" in table_types - assert "EXTERNAL" not in table_types - - def test_filter_by_column_values(self, sea_result_set_with_tables): - """Test filtering by values in a specific column.""" - # Filter by namespace in column index 0 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] - ) - - # All rows have schema1 in namespace, so all should be included - assert len(filtered_result._response["result"]["data_array"]) == 4 - - # Filter by table name in column index 1 - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, - column_index=1, - allowed_values=["table1", "table3"], - ) - - # Only rows with table1 or table3 should be included - assert len(filtered_result._response["result"]["data_array"]) == 2 - table_names = [ - row[1] for row in filtered_result._response["result"]["data_array"] - ] - assert "table1" in table_names - assert "table3" in table_names - assert "table2" not in table_names - 
assert "table4" not in table_names - - def test_filter_by_column_values_case_sensitive( - self, mock_connection, mock_sea_client, sea_response_with_tables - ): - """Test case-sensitive filtering by column values.""" - import copy - - # Create a fresh result set for the first test - sea_response_copy1 = copy.deepcopy(sea_response_with_tables) - result_set1 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy1, - buffer_size_bytes=1000, - arraysize=100, - ) - - # First test: Case-sensitive filtering with lowercase values (should find no matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set1, - column_index=5, # tableType column - allowed_values=["table", "view"], # lowercase - case_sensitive=True, - ) - - # Verify no matches with lowercase values - assert len(filtered_result._response["result"]["data_array"]) == 0 - - # Create a fresh result set for the second test - sea_response_copy2 = copy.deepcopy(sea_response_with_tables) - result_set2 = SeaResultSet( - connection=mock_connection, - sea_client=mock_sea_client, - sea_response=sea_response_copy2, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Second test: Case-sensitive filtering with correct case (should find matches) - filtered_result = ResultSetFilter.filter_by_column_values( - result_set2, - column_index=5, # tableType column - allowed_values=["TABLE", "VIEW"], # correct case - case_sensitive=True, - ) - - # Verify matches with correct case - assert len(filtered_result._response["result"]["data_array"]) == 2 - - # Extract the table types from the filtered results - table_types = [ - row[5] for row in filtered_result._response["result"]["data_array"] - ] - assert "TABLE" in table_types - assert "VIEW" in table_types - - def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): - """Test filtering with a column index that's out of bounds.""" - # Filter by column index 10 (out of bounds) - filtered_result = ResultSetFilter.filter_by_column_values( - sea_result_set_with_tables, column_index=10, allowed_values=["value"] - ) - - # No rows should match - assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index c2078f731..2fa362b8e 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -546,305 +546,3 @@ def test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) - - # Tests for metadata operations - - def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): - """Test getting catalogs.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method - result = sea_client.get_catalogs( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["lz4_compression"] is False - assert kwargs["use_cloud_fetch"] is False - assert kwargs["parameters"] == [] - 
assert kwargs["async_op"] is False - assert kwargs["enforce_embedded_schema_correctness"] is False - - def test_get_schemas_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_with_catalog_and_schema_pattern( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting schemas with catalog and schema pattern.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog and schema pattern - result = sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting schemas without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_schemas( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - ) - - assert "Catalog name is required for get_schemas" in str(excinfo.value) - - def test_get_tables_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert 
kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_all_catalogs( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables from all catalogs using wildcard.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with wildcard catalog - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="*", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_with_schema_and_table_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting tables with schema and table patterns.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with catalog, schema, and table patterns - result = sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting tables without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_tables" in str(excinfo.value) - - def test_get_columns_with_catalog_only( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with only catalog name.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with only catalog - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" - assert kwargs["session_id"] == sea_session_id - assert 
kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_with_all_patterns( - self, sea_client, mock_cursor, sea_session_id - ): - """Test getting columns with all patterns specified.""" - # Mock execute_command to verify it's called with the right parameters - with patch.object( - sea_client, "execute_command", return_value="mock_result_set" - ) as mock_execute: - # Call the method with all patterns - result = sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name="test_catalog", - schema_name="test_schema", - table_name="test_table", - column_name="test_column", - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify execute_command was called with the right parameters - mock_execute.assert_called_once() - args, kwargs = mock_execute.call_args - assert ( - kwargs["operation"] - == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" - ) - assert kwargs["session_id"] == sea_session_id - assert kwargs["max_rows"] == 100 - assert kwargs["max_bytes"] == 1000 - assert kwargs["cursor"] == mock_cursor - assert kwargs["async_op"] is False - - def test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): - """Test getting columns without a catalog name raises an error.""" - with pytest.raises(ValueError) as excinfo: - sea_client.get_columns( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - catalog_name=None, # No catalog name - schema_name="test_schema", - table_name="test_table", - ) - - assert "Catalog name is required for get_columns" in str(excinfo.value) From 030edf8df3db487b7af8d910ee51240d1339229e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 04:55:56 +0000 Subject: [PATCH 068/204] raise NotImplementedErrror for metadata ops Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 57 +++++++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 80066ae82..b1ad7cf76 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,8 +1,7 @@ import logging -import uuid import time import re -from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set +from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, @@ -23,9 +22,7 @@ ) from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions -from databricks.sql.utils import SeaResultSetQueueFactory from databricks.sql.backend.sea.models.base import ( ResultData, ExternalLink, @@ -716,3 +713,55 @@ def get_execution_result( result_data=result_data, manifest=manifest, ) + + # == Metadata Operations == + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + raise NotImplementedError("get_catalogs is not implemented for SEA backend") + + def get_schemas( + self, + session_id: 
SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + raise NotImplementedError("get_schemas is not implemented for SEA backend") + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_tables is not implemented for SEA backend") + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_columns is not implemented for SEA backend") From 4e07f1ee60a163e5fd623b28ad703ffde1bf0ce2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:02:24 +0000 Subject: [PATCH 069/204] align SeaResultSet with new structure Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 58 +++++++++++++++++++++++-------- tests/unit/test_sea_result_set.py | 2 +- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..d6f6be3bd 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -19,7 +19,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -41,10 +41,11 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, + lz4_compressed: bool = False, + arrow_schema_bytes: bytes = b"", ): """ A ResultSet manages the results of a single command. @@ -72,9 +73,10 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation + self.lz4_compressed = lz4_compressed + self._arrow_schema_bytes = arrow_schema_bytes def __iter__(self): while True: @@ -157,7 +159,10 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + has_more_rows: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
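For reference, a minimal sketch of how a backend method might construct a ThriftResultSet under the reworked signature above; caller-side names such as `resp`, `cursor`, and `self._ssl_options` are assumptions for illustration, not taken from this patch:

    result_set = ThriftResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        thrift_client=self,
        buffer_size_bytes=cursor.buffer_size_bytes,
        arraysize=cursor.arraysize,
        use_cloud_fetch=use_cloud_fetch,
        t_row_set=resp.results,  # TRowSet; the constructor now builds its own results queue from this
        max_download_threads=self.max_download_threads,
        ssl_options=self._ssl_options,
        has_more_rows=has_more_rows,
    )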
@@ -169,12 +174,30 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set + t_row_set: The TRowSet containing result data (if available) + max_download_threads: Maximum number of download threads for cloud fetch + ssl_options: SSL options for cloud fetch + has_more_rows: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self.lz4_compressed = execute_response.lz4_compressed + self.has_more_rows = has_more_rows + + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ThriftResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ThriftResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) # Call parent constructor with common attributes super().__init__( @@ -185,10 +208,11 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Initialize results queue if not provided @@ -419,7 +443,7 @@ def map_col_type(type_): class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" + """ResultSet implementation for SEA backend.""" def __init__( self, @@ -428,17 +452,20 @@ def __init__( sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, + result_data=None, + manifest=None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. 
Args: connection: The parent connection + execute_response: Response from the execute command sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) + result_data: Result data from SEA response (optional) + manifest: Manifest from SEA response (optional) """ super().__init__( @@ -449,15 +476,15 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") def fetchone(self) -> Optional[Row]: """ @@ -480,6 +507,7 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ + raise NotImplementedError("fetchall is not implemented for SEA backend") def fetchmany_arrow(self, size: int) -> Any: diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 02421a915..b691872af 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -197,4 +197,4 @@ def test_fill_results_buffer_not_implemented( with pytest.raises( NotImplementedError, match="fetchone is not implemented for SEA backend" ): - result_set._fill_results_buffer() \ No newline at end of file + result_set._fill_results_buffer() From 65e7c6be97f94e6db0031c1501ebcb7f0c43fc9c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:05:25 +0000 Subject: [PATCH 070/204] correct sea res set tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 6 ++++-- tests/unit/test_sea_result_set.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index d6f6be3bd..3ff0cc378 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -19,7 +19,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue +from databricks.sql.utils import ColumnTable, ColumnQueue from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -484,7 +484,9 @@ def __init__( def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") + raise NotImplementedError( + "_fill_results_buffer is not implemented for SEA backend" + ) def fetchone(self) -> Optional[Row]: """ diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index b691872af..d5d8a3667 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -195,6 +195,7 @@ def 
test_fill_results_buffer_not_implemented( ) with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, + match="_fill_results_buffer is not implemented for SEA backend", ): result_set._fill_results_buffer() From 30f82666804d0104bb419836def6b56b5dda3f8e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:10:50 +0000 Subject: [PATCH 071/204] add metadata commands Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 167 ++++++++++++++++++++++ src/databricks/sql/backend/sea/backend.py | 103 ++++++++++++- tests/unit/test_filters.py | 120 ++++++++++++++++ 3 files changed, 386 insertions(+), 4 deletions(-) create mode 100644 src/databricks/sql/backend/filters.py create mode 100644 tests/unit/test_filters.py diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 000000000..2c0105aee --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,167 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Dict, + Callable, + TypeVar, + Generic, + cast, + TYPE_CHECKING, +) + +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse, CommandId +from databricks.sql.backend.sea.models.base import ResultData +from databricks.sql.backend.sea.backend import SeaDatabricksClient + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet, SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. 
+ + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Get all remaining rows + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Import SeaResultSet here to avoid circular imports + from databricks.sql.result_set import SeaResultSet + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse with the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + description=result_set.description, + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=result_set.lz4_compressed, + arrow_schema_bytes=result_set._arrow_schema_bytes, + is_staging_operation=False, + ) + + # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData + + result_data = ResultData(data=filtered_rows, external_links=None) + + # Create a new SeaResultSet with the filtered data + filtered_result_set = SeaResultSet( + connection=result_set.connection, + execute_response=execute_response, + sea_client=cast(SeaDatabricksClient, result_set.backend), + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + result_data=result_data, + ) + + return filtered_result_set + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. + + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + from databricks.sql.result_set import SeaResultSet + + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. 
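A minimal usage sketch for the client-side filter described above, assuming `tables_result` is a SeaResultSet returned by a SHOW TABLES call (the variable name is illustrative):

    from databricks.sql.backend.filters import ResultSetFilter

    # Keep only tables and views; passing None or [] falls back to the
    # defaults ["TABLE", "VIEW", "SYSTEM TABLE"].
    filtered = ResultSetFilter.filter_tables_by_type(tables_result, ["TABLE", "VIEW"])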
+ + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=True + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b1ad7cf76..2807975cd 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -724,7 +724,20 @@ def get_catalogs( cursor: "Cursor", ) -> "ResultSet": """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -736,7 +749,28 @@ def get_schemas( schema_name: Optional[str] = None, ) -> "ResultSet": """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -750,7 +784,41 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -764,4 +832,31 @@ def get_columns( column_name: Optional[str] = None, ) -> "ResultSet": 
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result \ No newline at end of file diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 000000000..49bd1c328 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,120 @@ +""" +Tests for the ResultSetFilter class. +""" + +import unittest +from unittest.mock import MagicMock, patch +import sys +from typing import List, Dict, Any + +# Add the necessary path to import the filter module +sys.path.append("/home/varun.edachali/conn/databricks-sql-python/src") + +from databricks.sql.backend.filters import ResultSetFilter + + +class TestResultSetFilter(unittest.TestCase): + """Tests for the ResultSetFilter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock SeaResultSet + self.mock_sea_result_set = MagicMock() + self.mock_sea_result_set._response = { + "result": { + "data_array": [ + ["catalog1", "schema1", "table1", "TABLE", ""], + ["catalog1", "schema1", "table2", "VIEW", ""], + ["catalog1", "schema1", "table3", "SYSTEM TABLE", ""], + ["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""], + ], + "row_count": 4, + } + } + + # Set up the connection and other required attributes + self.mock_sea_result_set.connection = MagicMock() + self.mock_sea_result_set.backend = MagicMock() + self.mock_sea_result_set.buffer_size_bytes = 1000 + self.mock_sea_result_set.arraysize = 100 + self.mock_sea_result_set.statement_id = "test-statement-id" + + # Create a mock CommandId + from databricks.sql.backend.types import CommandId, BackendType + + mock_command_id = CommandId(BackendType.SEA, "test-statement-id") + self.mock_sea_result_set.command_id = mock_command_id + + self.mock_sea_result_set.status = MagicMock() + self.mock_sea_result_set.description = [ + ("catalog_name", "string", None, None, None, None, True), + ("schema_name", "string", None, None, None, None, True), + ("table_name", "string", None, None, None, None, True), + ("table_type", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] + self.mock_sea_result_set.has_been_closed_server_side = False + + def test_filter_tables_by_type(self): + """Test filtering tables by type.""" + # Test with specific table types + table_types = ["TABLE", "VIEW"] + + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = 
ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + def test_filter_tables_by_type_case_insensitive(self): + """Test filtering tables by type with case insensitivity.""" + # Test with lowercase table types + table_types = ["table", "view"] + + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + def test_filter_tables_by_type_default(self): + """Test filtering tables by type with default types.""" + # Make the mock_sea_result_set appear to be a SeaResultSet + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + # Set up the mock to return a new mock when instantiated + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + result = ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, None + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + +if __name__ == "__main__": + unittest.main() From 033ae73440dad3295ac097da5809eff4563be7b0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:12:04 +0000 Subject: [PATCH 072/204] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2807975cd..1e4eb3253 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -859,4 +859,4 @@ def get_columns( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" - return result \ No newline at end of file + return result From 33821f46f0531fbc2bb08dc28002c33b46e0f485 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 05:41:54 +0000 Subject: [PATCH 073/204] add metadata command unit tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 1 - tests/unit/test_sea_backend.py | 442 ++++++++++++++++++++++++++ 2 files changed, 442 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 2c0105aee..ec91d87da 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -17,7 +17,6 @@ TYPE_CHECKING, ) -from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import ExecuteResponse, CommandId from databricks.sql.backend.sea.models.base import ResultData from databricks.sql.backend.sea.backend import SeaDatabricksClient diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..0b6f10803 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -546,3 +546,445 @@ def 
test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) + + # Tests for metadata commands + + def test_get_catalogs( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting catalogs metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + def test_get_schemas( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting schemas metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW SCHEMAS IN `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog name and schema pattern + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema%", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) + + def test_get_tables( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting tables metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Mock the get_tables method to avoid import errors + original_get_tables = sea_client.get_tables + try: + # Replace get_tables with a 
simple version that doesn't use ResultSetFilter + def mock_get_tables( + session_id, + max_rows, + max_bytes, + cursor, + catalog_name, + schema_name=None, + table_name=None, + table_types=None, + ): + if catalog_name is None: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + return sea_client.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + sea_client.get_tables = mock_get_tables + + # Call the method + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog and schema name + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: With catalog, schema, and table name + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table%", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 4: With wildcard catalog + mock_execute.reset_mock() + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW TABLES IN ALL CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + 
async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 5: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) + finally: + # Restore the original method + sea_client.get_tables = original_get_tables + + def test_get_columns( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test getting columns metadata.""" + # Set up mock for execute_command + mock_result_set = Mock() + + # Test case 1: With catalog name only + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog`", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 2: With catalog and schema name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 3: With catalog, schema, and table name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 4: With catalog, schema, table, and column name + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call the method + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="col%", + ) + + # Verify the result + 
assert result == mock_result_set + + # Verify execute_command was called with correct parameters + mock_execute.assert_called_once_with( + operation="SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'col%'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Test case 5: Missing catalog name should raise error + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, + ) + + assert "Catalog name is required" in str(excinfo.value) From 71b451a53216ea5617933ab007792e3b9ff98488 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:09:17 +0000 Subject: [PATCH 074/204] minimal fetch phase intro Signed-off-by: varun-edachali-dbx --- .../experimental/tests/test_sea_sync_query.py | 3 + src/databricks/sql/backend/thrift_backend.py | 10 +- src/databricks/sql/result_set.py | 120 ++++++++++++++++-- src/databricks/sql/utils.py | 68 +++++++++- 4 files changed, 186 insertions(+), 15 deletions(-) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 07be8aafc..f44246fad 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -122,6 +122,9 @@ def test_sea_sync_query_without_cloud_fetch(): cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch disabled") + rows = cursor.fetchall() + logger.info(f"Rows: {rows}") + # Close resources cursor.close() connection.close() diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..da9e617f7 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -42,11 +42,11 @@ ) from databricks.sql.utils import ( - ResultSetQueueFactory, + ThriftResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, - ResultSetQueueFactory, + ThriftResultSetQueueFactory, convert_arrow_based_set_to_arrow_table, convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, @@ -784,7 +784,7 @@ def _results_message_to_execute_response(self, resp, operation_state): assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata - arrow_queue_opt = ResultSetQueueFactory.build_queue( + arrow_queue_opt = ThriftResultSetQueueFactory.build_queue( row_set_type=t_result_set_metadata_resp.resultFormat, t_row_set=direct_results.resultSet.results, arrow_schema_bytes=schema_bytes, @@ -857,7 +857,7 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=schema_bytes, @@ -1225,7 +1225,7 @@ def fetch_results( ) ) - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..900fe1786 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -6,6 +6,7 @@ 
import pandas from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest try: import pyarrow @@ -19,7 +20,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -441,6 +442,14 @@ def __init__( sea_response: Direct SEA response (legacy style) """ + queue = SeaResultSetQueueFactory.build_queue( + sea_result_data=execute_response.results_data, + manifest=execute_response.results_manifest, + statement_id=execute_response.command_id.to_sea_statement_id(), + description=execute_response.description, + schema_bytes=execute_response.arrow_schema_bytes, + ) + super().__init__( connection=connection, backend=sea_client, @@ -450,22 +459,69 @@ def __init__( status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + results_queue=queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) + + def _convert_to_row_objects(self, rows): + """ + Convert raw data rows to Row objects with named columns based on description. + + Args: + rows: List of raw data rows + + Returns: + List of Row objects with named columns + """ + if not self.description or not rows: + return rows + + column_names = [col[0] for col in self.description] + ResultRow = Row(*column_names) + return [ResultRow(*row) for row in rows] def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") + return None + + def _convert_rows_to_arrow_table(self, rows): + """Convert rows to Arrow table.""" + if not self.description: + return pyarrow.Table.from_pylist([]) + + # Create dict of column data + column_data = {} + column_names = [col[0] for col in self.description] + + for i, name in enumerate(column_names): + column_data[name] = [row[i] for row in rows] + + return pyarrow.Table.from_pydict(column_data) + + def _create_empty_arrow_table(self): + """Create an empty Arrow table with the correct schema.""" + if not self.description: + return pyarrow.Table.from_pylist([]) + + column_names = [col[0] for col in self.description] + return pyarrow.Table.from_pydict({name: [] for name in column_names}) def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") + if isinstance(self.results, JsonQueue): + rows = self.results.next_n_rows(1) + if not rows: + return None + + # Convert to Row object + converted_rows = self._convert_to_row_objects(rows) + return converted_rows[0] if converted_rows else None + else: + raise NotImplementedError("Unsupported queue type") def fetchmany(self, size: Optional[int] = None) -> List[Row]: """ @@ -473,19 +529,65 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: An empty sequence is returned when no more rows are available. 
""" + if size is None: + size = self.arraysize + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + # Note: We check for the specific queue type to maintain consistency with ThriftResultSet + if isinstance(self.results, JsonQueue): + rows = self.results.next_n_rows(size) + self._next_row_index += len(rows) - raise NotImplementedError("fetchmany is not implemented for SEA backend") + # Convert to Row objects + return self._convert_to_row_objects(rows) + else: + raise NotImplementedError("Unsupported queue type") def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - raise NotImplementedError("fetchall is not implemented for SEA backend") + # Note: We check for the specific queue type to maintain consistency with ThriftResultSet + if isinstance(self.results, JsonQueue): + rows = self.results.remaining_rows() + self._next_row_index += len(rows) + + # Convert to Row objects + return self._convert_to_row_objects(rows) + else: + raise NotImplementedError("Unsupported queue type") def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + if not pyarrow: + raise ImportError("PyArrow is required for Arrow support") + + if isinstance(self.results, JsonQueue): + rows = self.fetchmany(size) + if not rows: + # Return empty Arrow table with schema + return self._create_empty_arrow_table() + + # Convert rows to Arrow table + return self._convert_rows_to_arrow_table(rows) + else: + raise NotImplementedError("Unsupported queue type") def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") + if not pyarrow: + raise ImportError("PyArrow is required for Arrow support") + + if isinstance(self.results, JsonQueue): + rows = self.fetchall() + if not rows: + # Return empty Arrow table with schema + return self._create_empty_arrow_table() + + # Convert rows to Arrow table + return self._convert_rows_to_arrow_table(rows) + else: + raise NotImplementedError("Unsupported queue type") + diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index edb13ef6d..6e14287ac 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,6 +13,9 @@ import lz4.frame +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + try: import pyarrow except ImportError: @@ -48,7 +51,7 @@ def remaining_rows(self): pass -class ResultSetQueueFactory(ABC): +class ThriftResultSetQueueFactory(ABC): @staticmethod def build_queue( row_set_type: TSparkRowSetType, @@ -106,6 +109,69 @@ def build_queue( else: raise AssertionError("Row set type is not valid") +class SeaResultSetQueueFactory(ABC): + @staticmethod + def build_queue( + sea_result_data: ResultData, + manifest: Optional[ResultManifest], + statement_id: str, + description: Optional[List[Tuple[Any, ...]]] = None, + schema_bytes: Optional[bytes] = None, + max_download_threads: Optional[int] = None, + ssl_options: Optional[SSLOptions] = None, + sea_client: Optional["SeaDatabricksClient"] = None, + lz4_compressed: bool = False, + ) -> ResultSetQueue: + """ + Factory method to build a result set queue for SEA backend. 
+ + Args: + sea_result_data (ResultData): Result data from SEA response + manifest (ResultManifest): Manifest from SEA response + statement_id (str): Statement ID for the query + description (List[List[Any]]): Column descriptions + schema_bytes (bytes): Arrow schema bytes + max_download_threads (int): Maximum number of download threads + ssl_options (SSLOptions): SSL options for downloads + sea_client (SeaDatabricksClient): SEA client for fetching additional links + lz4_compressed (bool): Whether the data is LZ4 compressed + + Returns: + ResultSetQueue: The appropriate queue for the result data + """ + + if sea_result_data.data is not None: + # INLINE disposition with JSON_ARRAY format + return JsonQueue(sea_result_data.data) + elif sea_result_data.external_links is not None: + # EXTERNAL_LINKS disposition + raise NotImplementedError("EXTERNAL_LINKS disposition is not implemented for SEA backend") + else: + # Empty result set + return JsonQueue([]) + + +class JsonQueue(ResultSetQueue): + """Queue implementation for JSON_ARRAY format data.""" + + def __init__(self, data_array): + """Initialize with JSON array data.""" + self.data_array = data_array + self.cur_row_index = 0 + self.n_valid_rows = len(data_array) + + def next_n_rows(self, num_rows): + """Get the next n rows from the data array.""" + length = min(num_rows, self.n_valid_rows - self.cur_row_index) + slice = self.data_array[self.cur_row_index : self.cur_row_index + length] + self.cur_row_index += length + return slice + + def remaining_rows(self): + """Get all remaining rows from the data array.""" + slice = self.data_array[self.cur_row_index :] + self.cur_row_index += len(slice) + return slice class ColumnTable: def __init__(self, column_table, column_names): From c038d5a17d157bde12555b5dfcbb7079a803b8d0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:33:43 +0000 Subject: [PATCH 075/204] working JSON + INLINE Signed-off-by: varun-edachali-dbx --- .../experimental/tests/test_sea_metadata.py | 12 ++++- src/databricks/sql/result_set.py | 46 ++++++++++++------- src/databricks/sql/utils.py | 6 ++- 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index c715e5984..24b006c62 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -56,26 +56,34 @@ def test_sea_metadata(): cursor = connection.cursor() logger.info("Fetching catalogs...") cursor.catalogs() + rows = cursor.fetchall() + logger.info(f"Rows: {rows}") logger.info("Successfully fetched catalogs") # Test schemas logger.info(f"Fetching schemas for catalog '{catalog}'...") cursor.schemas(catalog_name=catalog) + rows = cursor.fetchall() + logger.info(f"Rows: {rows}") logger.info("Successfully fetched schemas") # Test tables logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") cursor.tables(catalog_name=catalog, schema_name="default") + rows = cursor.fetchall() + logger.info(f"Rows: {rows}") logger.info("Successfully fetched tables") # Test columns for a specific table # Using a common table that should exist in most environments logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." 
) cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" + catalog_name=catalog, schema_name="default", table_name="customer" ) + rows = cursor.fetchall() + logger.info(f"Rows: {rows}") logger.info("Successfully fetched columns") # Close resources diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ab32468f7..ece357f33 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,7 +20,12 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue, SeaResultSetQueueFactory +from databricks.sql.utils import ( + ColumnTable, + ColumnQueue, + JsonQueue, + SeaResultSetQueueFactory, +) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -469,8 +474,8 @@ def __init__( sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - result_data=None, - manifest=None, + result_data: Optional[ResultData] = None, + manifest: Optional[ResultManifest] = None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -485,13 +490,17 @@ def __init__( manifest: Manifest from SEA response (optional) """ - queue = SeaResultSetQueueFactory.build_queue( - sea_result_data=execute_response.results_data, - manifest=execute_response.results_manifest, - statement_id=execute_response.command_id.to_sea_statement_id(), - description=execute_response.description, - schema_bytes=execute_response.arrow_schema_bytes, - ) + if result_data: + queue = SeaResultSetQueueFactory.build_queue( + sea_result_data=result_data, + manifest=manifest, + statement_id=execute_response.command_id.to_sea_statement_id(), + description=execute_response.description, + schema_bytes=execute_response.arrow_schema_bytes, + ) + else: + logger.warning("No result data provided for SEA result set") + queue = JsonQueue([]) super().__init__( connection=connection, @@ -501,12 +510,13 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, arrow_schema_bytes=execute_response.arrow_schema_bytes, ) - + def _convert_to_row_objects(self, rows): """ Convert raw data rows to Row objects with named columns based on description. @@ -526,9 +536,7 @@ def _convert_to_row_objects(self, rows): def _fill_results_buffer(self): """Fill the results buffer from the backend.""" - raise NotImplementedError( - "_fill_results_buffer is not implemented for SEA backend" - ) + return None def fetchone(self) -> Optional[Row]: """ @@ -572,8 +580,15 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. 
""" + # Note: We check for the specific queue type to maintain consistency with ThriftResultSet + if isinstance(self.results, JsonQueue): + rows = self.results.remaining_rows() + self._next_row_index += len(rows) - raise NotImplementedError("fetchall is not implemented for SEA backend") + # Convert to Row objects + return self._convert_to_row_objects(rows) + else: + raise NotImplementedError("Unsupported queue type") def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" @@ -606,4 +621,3 @@ def fetchall_arrow(self) -> Any: return self._convert_rows_to_arrow_table(rows) else: raise NotImplementedError("Unsupported queue type") - diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c415d2127..d3f2d9ee3 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -109,6 +109,7 @@ def build_queue( else: raise AssertionError("Row set type is not valid") + class SeaResultSetQueueFactory(ABC): @staticmethod def build_queue( @@ -145,7 +146,9 @@ def build_queue( return JsonQueue(sea_result_data.data) elif sea_result_data.external_links is not None: # EXTERNAL_LINKS disposition - raise NotImplementedError("EXTERNAL_LINKS disposition is not implemented for SEA backend") + raise NotImplementedError( + "EXTERNAL_LINKS disposition is not implemented for SEA backend" + ) else: # Empty result set return JsonQueue([]) @@ -173,6 +176,7 @@ def remaining_rows(self): self.cur_row_index += len(slice) return slice + class ColumnTable: def __init__(self, column_table, column_names): self.column_table = column_table From 3e22c6c4f297a3c83dbebba7c57e3bc8c0c5fe9a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 06:34:34 +0000 Subject: [PATCH 076/204] change to valid table name Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index c715e5984..394c48b24 100644 --- a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -74,7 +74,7 @@ def test_sea_metadata(): f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." 
) cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" + catalog_name=catalog, schema_name="default", table_name="customer" ) logger.info("Successfully fetched columns") From 716304b99d08e5d399d5c1a22628ce5fe3dc7a9c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 08:20:47 +0000 Subject: [PATCH 077/204] rmeove redundant queue init Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 4 +- src/databricks/sql/result_set.py | 182 ++++++--- tests/unit/test_sea_backend.py | 2 +- tests/unit/test_sea_result_set.py | 371 ++++++++++++++++--- tests/unit/test_thrift_backend.py | 9 +- 5 files changed, 457 insertions(+), 111 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f0a53e695..fc0adf915 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1224,9 +1224,9 @@ def fetch_results( ) ) - from databricks.sql.utils import ResultSetQueueFactory + from databricks.sql.utils import ThriftResultSetQueueFactory - queue = ResultSetQueueFactory.build_queue( + queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ece357f33..bd5897fb7 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -51,7 +51,7 @@ def __init__( description=None, is_staging_operation: bool = False, lz4_compressed: bool = False, - arrow_schema_bytes: bytes = b"", + arrow_schema_bytes: Optional[bytes] = b"", ): """ A ResultSet manages the results of a single command. @@ -205,22 +205,6 @@ def __init__( ssl_options=ssl_options, ) - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, @@ -543,16 +527,13 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. 
""" - if isinstance(self.results, JsonQueue): - rows = self.results.next_n_rows(1) - if not rows: - return None + rows = self.results.next_n_rows(1) + if not rows: + return None - # Convert to Row object - converted_rows = self._convert_to_row_objects(rows) - return converted_rows[0] if converted_rows else None - else: - raise NotImplementedError("Unsupported queue type") + # Convert to Row object + converted_rows = self._convert_to_row_objects(rows) + return converted_rows[0] if converted_rows else None def fetchmany(self, size: Optional[int] = None) -> List[Row]: """ @@ -566,58 +547,141 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - # Note: We check for the specific queue type to maintain consistency with ThriftResultSet - if isinstance(self.results, JsonQueue): - rows = self.results.next_n_rows(size) - self._next_row_index += len(rows) + rows = self.results.next_n_rows(size) + self._next_row_index += len(rows) - # Convert to Row objects - return self._convert_to_row_objects(rows) - else: - raise NotImplementedError("Unsupported queue type") + # Convert to Row objects + return self._convert_to_row_objects(rows) def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - # Note: We check for the specific queue type to maintain consistency with ThriftResultSet - if isinstance(self.results, JsonQueue): - rows = self.results.remaining_rows() - self._next_row_index += len(rows) - # Convert to Row objects - return self._convert_to_row_objects(rows) + rows = self.results.remaining_rows() + self._next_row_index += len(rows) + + # Convert to Row objects + return self._convert_to_row_objects(rows) + + def _create_empty_arrow_table(self) -> Any: + """ + Create an empty PyArrow table with the schema from the result set. + + Returns: + An empty PyArrow table with the correct schema. 
+ """ + import pyarrow + + # Try to use schema bytes if available + if self._arrow_schema_bytes: + schema = pyarrow.ipc.read_schema( + pyarrow.BufferReader(self._arrow_schema_bytes) + ) + return pyarrow.Table.from_pydict( + {name: [] for name in schema.names}, schema=schema + ) + + # Fall back to creating schema from description + if self.description: + # Map SQL types to PyArrow types + type_map = { + "boolean": pyarrow.bool_(), + "tinyint": pyarrow.int8(), + "smallint": pyarrow.int16(), + "int": pyarrow.int32(), + "bigint": pyarrow.int64(), + "float": pyarrow.float32(), + "double": pyarrow.float64(), + "string": pyarrow.string(), + "binary": pyarrow.binary(), + "timestamp": pyarrow.timestamp("us"), + "date": pyarrow.date32(), + "decimal": pyarrow.decimal128(38, 18), # Default precision and scale + } + + fields = [] + for col_desc in self.description: + col_name = col_desc[0] + col_type = col_desc[1].lower() if col_desc[1] else "string" + + # Handle decimal with precision and scale + if ( + col_type == "decimal" + and col_desc[4] is not None + and col_desc[5] is not None + ): + arrow_type = pyarrow.decimal128(col_desc[4], col_desc[5]) + else: + arrow_type = type_map.get(col_type, pyarrow.string()) + + fields.append(pyarrow.field(col_name, arrow_type)) + + schema = pyarrow.schema(fields) + return pyarrow.Table.from_pydict( + {name: [] for name in schema.names}, schema=schema + ) + + # If no schema information is available, return an empty table + return pyarrow.Table.from_pydict({}) + + def _convert_rows_to_arrow_table(self, rows: List[Row]) -> Any: + """ + Convert a list of Row objects to a PyArrow table. + + Args: + rows: List of Row objects to convert. + + Returns: + PyArrow table containing the data from the rows. + """ + import pyarrow + + if not rows: + return self._create_empty_arrow_table() + + # Extract column names from description + if self.description: + column_names = [col[0] for col in self.description] else: - raise NotImplementedError("Unsupported queue type") + # If no description, use the attribute names from the first row + column_names = rows[0]._fields + + # Convert rows to columns + columns: dict[str, list] = {name: [] for name in column_names} + + for row in rows: + for i, name in enumerate(column_names): + if hasattr(row, "_asdict"): # If it's a Row object + columns[name].append(row[i]) + else: # If it's a raw list + columns[name].append(row[i]) + + # Create PyArrow table + return pyarrow.Table.from_pydict(columns) def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" if not pyarrow: raise ImportError("PyArrow is required for Arrow support") - if isinstance(self.results, JsonQueue): - rows = self.fetchmany(size) - if not rows: - # Return empty Arrow table with schema - return self._create_empty_arrow_table() + rows = self.fetchmany(size) + if not rows: + # Return empty Arrow table with schema + return self._create_empty_arrow_table() - # Convert rows to Arrow table - return self._convert_rows_to_arrow_table(rows) - else: - raise NotImplementedError("Unsupported queue type") + # Convert rows to Arrow table + return self._convert_rows_to_arrow_table(rows) def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" if not pyarrow: raise ImportError("PyArrow is required for Arrow support") - if isinstance(self.results, JsonQueue): - rows = self.fetchall() - if not rows: - # Return empty Arrow table with schema - return self._create_empty_arrow_table() + rows = self.fetchall() + if not rows: + # Return 
empty Arrow table with schema + return self._create_empty_arrow_table() - # Convert rows to Arrow table - return self._convert_rows_to_arrow_table(rows) - else: - raise NotImplementedError("Unsupported queue type") + # Convert rows to Arrow table + return self._convert_rows_to_arrow_table(rows) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 0b6f10803..e1c85fb9f 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -536,7 +536,7 @@ def test_get_execution_result( print(result) # Verify basic properties of the result - assert result.statement_id == "test-statement-123" + assert result.command_id.to_sea_statement_id() == "test-statement-123" assert result.status == CommandState.SUCCEEDED # Verify the HTTP request diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index d5d8a3667..85ad60501 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -123,10 +123,22 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_unimplemented_methods( + @pytest.fixture + def mock_results_queue(self): + """Create a mock results queue.""" + mock_queue = Mock() + mock_queue.next_n_rows.return_value = [["value1", 123], ["value2", 456]] + mock_queue.remaining_rows.return_value = [ + ["value1", 123], + ["value2", 456], + ["value3", 789], + ] + return mock_queue + + def test_fill_results_buffer( self, mock_connection, mock_sea_client, execute_response ): - """Test that unimplemented methods raise NotImplementedError.""" + """Test that _fill_results_buffer returns None.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -135,57 +147,195 @@ def test_unimplemented_methods( arraysize=100, ) - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() + assert result_set._fill_results_buffer() is None - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) + def test_convert_to_row_objects( + self, mock_connection, mock_sea_client, execute_response + ): + """Test converting raw data rows to Row objects.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() + # Test with empty description + result_set.description = None + rows = [["value1", 123], ["value2", 456]] + converted_rows = result_set._convert_to_row_objects(rows) + assert converted_rows == rows - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() + # Test with empty rows + result_set.description = [ + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), + ] + assert result_set._convert_to_row_objects([]) == [] - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) + # Test with description and rows + rows = [["value1", 123], ["value2", 456]] + converted_rows = 
result_set._convert_to_row_objects(rows) + assert len(converted_rows) == 2 + assert converted_rows[0].col1 == "value1" + assert converted_rows[0].col2 == 123 + assert converted_rows[1].col1 == "value2" + assert converted_rows[1].col2 == 456 - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() + def test_fetchone( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test fetchone method.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), + ] - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) + # Mock the next_n_rows to return a single row + mock_results_queue.next_n_rows.return_value = [["value1", 123]] + + row = result_set.fetchone() + assert row is not None + assert row.col1 == "value1" + assert row.col2 == 123 + + # Test when no rows are available + mock_results_queue.next_n_rows.return_value = [] + assert result_set.fetchone() is None + + def test_fetchmany( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test fetchmany method.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), + ] + # Test with specific size + rows = result_set.fetchmany(2) + assert len(rows) == 2 + assert rows[0].col1 == "value1" + assert rows[0].col2 == 123 + assert rows[1].col1 == "value2" + assert rows[1].col2 == 456 + + # Test with default size (arraysize) + result_set.arraysize = 2 + mock_results_queue.next_n_rows.reset_mock() + rows = result_set.fetchmany() + mock_results_queue.next_n_rows.assert_called_with(2) + + # Test with negative size with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + ValueError, match="size argument for fetchmany is -1 but must be >= 0" ): - # Test using the result set in a for loop - for row in result_set: - pass + result_set.fetchmany(-1) + + def test_fetchall( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test fetchall method.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), + ] + + rows = result_set.fetchall() + assert len(rows) == 3 + assert rows[0].col1 == "value1" + assert rows[0].col2 == 123 + assert rows[1].col1 == "value2" + assert rows[1].col2 == 456 + assert rows[2].col1 == "value3" + assert rows[2].col2 == 789 + + # Verify _next_row_index is updated + assert result_set._next_row_index == 3 + + @pytest.mark.skipif( + pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, + reason="PyArrow is not 
installed", + ) + def test_create_empty_arrow_table( + self, mock_connection, mock_sea_client, execute_response, monkeypatch + ): + """Test creating an empty Arrow table with schema.""" + import pyarrow - def test_fill_results_buffer_not_implemented( + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Mock _arrow_schema_bytes to return a valid schema + schema = pyarrow.schema( + [ + pyarrow.field("col1", pyarrow.string()), + pyarrow.field("col2", pyarrow.int32()), + ] + ) + schema_bytes = schema.serialize().to_pybytes() + monkeypatch.setattr(result_set, "_arrow_schema_bytes", schema_bytes) + + # Test with schema bytes + empty_table = result_set._create_empty_arrow_table() + assert isinstance(empty_table, pyarrow.Table) + assert empty_table.num_rows == 0 + assert empty_table.num_columns == 2 + assert empty_table.schema.names == ["col1", "col2"] + + # Test without schema bytes but with description + monkeypatch.setattr(result_set, "_arrow_schema_bytes", b"") + result_set.description = [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ] + + empty_table = result_set._create_empty_arrow_table() + assert isinstance(empty_table, pyarrow.Table) + assert empty_table.num_rows == 0 + assert empty_table.num_columns == 2 + assert empty_table.schema.names == ["col1", "col2"] + + @pytest.mark.skipif( + pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, + reason="PyArrow is not installed", + ) + def test_convert_rows_to_arrow_table( self, mock_connection, mock_sea_client, execute_response ): - """Test that _fill_results_buffer raises NotImplementedError.""" + """Test converting rows to Arrow table.""" + import pyarrow + result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -194,8 +344,137 @@ def test_fill_results_buffer_not_implemented( arraysize=100, ) - with pytest.raises( - NotImplementedError, - match="_fill_results_buffer is not implemented for SEA backend", - ): - result_set._fill_results_buffer() + result_set.description = [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ] + + rows = [["value1", 123], ["value2", 456], ["value3", 789]] + + arrow_table = result_set._convert_rows_to_arrow_table(rows) + assert isinstance(arrow_table, pyarrow.Table) + assert arrow_table.num_rows == 3 + assert arrow_table.num_columns == 2 + assert arrow_table.schema.names == ["col1", "col2"] + + # Check data + assert arrow_table.column(0).to_pylist() == ["value1", "value2", "value3"] + assert arrow_table.column(1).to_pylist() == [123, 456, 789] + + @pytest.mark.skipif( + pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, + reason="PyArrow is not installed", + ) + def test_fetchmany_arrow( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test fetchmany_arrow method.""" + import pyarrow + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ] + + # Test with data + arrow_table = result_set.fetchmany_arrow(2) + assert isinstance(arrow_table, pyarrow.Table) + assert 
arrow_table.num_rows == 2 + assert arrow_table.column(0).to_pylist() == ["value1", "value2"] + assert arrow_table.column(1).to_pylist() == [123, 456] + + # Test with no data + mock_results_queue.next_n_rows.return_value = [] + + # Mock _create_empty_arrow_table to return an empty table + result_set._create_empty_arrow_table = Mock() + empty_table = pyarrow.Table.from_pydict({"col1": [], "col2": []}) + result_set._create_empty_arrow_table.return_value = empty_table + + arrow_table = result_set.fetchmany_arrow(2) + assert isinstance(arrow_table, pyarrow.Table) + assert arrow_table.num_rows == 0 + result_set._create_empty_arrow_table.assert_called_once() + + @pytest.mark.skipif( + pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, + reason="PyArrow is not installed", + ) + def test_fetchall_arrow( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test fetchall_arrow method.""" + import pyarrow + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ] + + # Test with data + arrow_table = result_set.fetchall_arrow() + assert isinstance(arrow_table, pyarrow.Table) + assert arrow_table.num_rows == 3 + assert arrow_table.column(0).to_pylist() == ["value1", "value2", "value3"] + assert arrow_table.column(1).to_pylist() == [123, 456, 789] + + # Test with no data + mock_results_queue.remaining_rows.return_value = [] + + # Mock _create_empty_arrow_table to return an empty table + result_set._create_empty_arrow_table = Mock() + empty_table = pyarrow.Table.from_pydict({"col1": [], "col2": []}) + result_set._create_empty_arrow_table.return_value = empty_table + + arrow_table = result_set.fetchall_arrow() + assert isinstance(arrow_table, pyarrow.Table) + assert arrow_table.num_rows == 0 + result_set._create_empty_arrow_table.assert_called_once() + + def test_iteration_protocol( + self, mock_connection, mock_sea_client, execute_response, mock_results_queue + ): + """Test iteration protocol using fetchone.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_results_queue + result_set.description = [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ] + + # Set up mock to return different values on each call + mock_results_queue.next_n_rows.side_effect = [ + [["value1", 123]], + [["value2", 456]], + [], # End of data + ] + + # Test iteration + rows = list(result_set) + assert len(rows) == 2 + assert rows[0].col1 == "value1" + assert rows[0].col2 == 123 + assert rows[1].col1 == "value2" + assert rows[1].col2 == 456 diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index a05e8cb87..ca77348f4 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -610,7 +610,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): self.assertIn("some information about the error", str(cm.exception)) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) def 
test_handle_execute_response_sets_compression_in_direct_results( self, build_queue @@ -998,7 +999,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( @@ -1043,7 +1045,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, has_more_rows_result) @patch( - "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + "databricks.sql.utils.ThriftResultSetQueueFactory.build_queue", + return_value=Mock(), ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( From e96e5b8c950c2b7613333b5d20da537e9f3e6ceb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 08:37:06 +0000 Subject: [PATCH 078/204] large query results Signed-off-by: varun-edachali-dbx --- .../tests/test_sea_async_query.py | 20 +++++++++++------- .../experimental/tests/test_sea_sync_query.py | 21 ++++++++++--------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a776377c3..35135b64a 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -51,12 +51,12 @@ def test_sea_async_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # Execute a query that returns 100 rows asynchronously cursor = connection.cursor() - logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + logger.info("Executing asynchronous query with cloud fetch: SELECT 100 rows") + cursor.execute_async( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute_async("SELECT 1 as test_value") logger.info( "Asynchronous query submitted successfully with cloud fetch enabled" ) @@ -69,6 +69,8 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + rows = cursor.fetchall() + logger.info(f"Retrieved rows: {rows}") logger.info( "Successfully retrieved asynchronous query results with cloud fetch enabled" ) @@ -130,12 +132,12 @@ def test_sea_async_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query asynchronously + # Execute a query that returns 100 rows asynchronously cursor = connection.cursor() - logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + logger.info("Executing asynchronous query without cloud fetch: SELECT 100 rows") + cursor.execute_async( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute_async("SELECT 1 as test_value") logger.info( "Asynchronous query submitted successfully with cloud fetch disabled" ) @@ -148,6 +150,8 @@ def test_sea_async_query_without_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() + rows = 
cursor.fetchall() + logger.info(f"Retrieved rows: {rows}") logger.info( "Successfully retrieved asynchronous query results with cloud fetch disabled" ) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index f44246fad..0f12445d1 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -49,13 +49,14 @@ def test_sea_sync_query_with_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # Execute a query that returns 100 rows cursor = connection.cursor() - logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + logger.info("Executing synchronous query with cloud fetch: SELECT 100 rows") + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") + rows = cursor.fetchall() + logger.info(f"Retrieved rows: {rows}") # Close resources cursor.close() @@ -114,16 +115,16 @@ def test_sea_sync_query_without_cloud_fetch(): f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - # Execute a simple query + # Execute a query that returns 100 rows cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows") + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - cursor.execute("SELECT 1 as test_value") logger.info("Query executed successfully with cloud fetch disabled") rows = cursor.fetchall() - logger.info(f"Rows: {rows}") + logger.info(f"Retrieved rows: {rows}") # Close resources cursor.close() From 165c4f35ce69f282b03e6522c6ea72c6d0a8f5fc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:18:39 +0000 Subject: [PATCH 079/204] remove un-necessary changes covered by #588 Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 11 +- src/databricks/sql/result_set.py | 73 ------- tests/unit/test_sea_result_set.py | 200 ------------------- tests/unit/test_thrift_backend.py | 32 +-- 4 files changed, 7 insertions(+), 309 deletions(-) delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d28a2c6fd..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -15,6 +15,7 @@ CommandId, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id try: @@ -841,8 +842,6 @@ def get_execution_result( status = self.get_query_state(command_id) - status = self.get_query_state(command_id) - execute_response = ExecuteResponse( command_id=command_id, status=status, @@ -895,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1189,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - 
execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 97b10cbbe..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -438,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. - """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index 02421a915..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. 
- -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - 
arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() \ No newline at end of file diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 88adcd3e9..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,13 +619,6 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), @@ -927,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -957,12 +948,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the 
operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) @@ -977,7 +962,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,12 +982,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req tcli_service_instance.GetOperationStatus.return_value = ( ttypes.TGetOperationStatusResp( @@ -1694,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2256,8 +2233,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class, mock_result_set From a6e40d0dce9acd43c29e2de76f7d64ce96f775a8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:25:51 +0000 Subject: [PATCH 080/204] simplify test module Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 41 +++++++++------------ 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..3a8b163f5 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,20 +1,18 @@ """ Main script to run all SEA connector tests. -This script imports and runs all the individual test modules and displays +This script runs all the individual test modules and displays a summary of test results with visual indicators. 
""" import os import sys import logging -import importlib.util -from typing import Dict, Callable, List, Tuple +import subprocess +from typing import List, Tuple -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Define test modules and their main test functions TEST_MODULES = [ "test_sea_session", "test_sea_sync_query", @@ -23,29 +21,27 @@ ] -def load_test_function(module_name: str) -> Callable: - """Load a test function from a module.""" +def run_test_module(module_name: str) -> bool: + """Run a test module and return success status.""" module_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + # Simply run the module as a script - each module handles its own test execution + result = subprocess.run( + [sys.executable, module_path], capture_output=True, text=True + ) - # Get the main test function (assuming it starts with "test_") - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - # For sync and async query modules, we want the main function that runs both tests - if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": - return getattr(module, name) + # Log the output from the test module + if result.stdout: + for line in result.stdout.strip().split("\n"): + logger.info(line) - # Fallback to the first test function found - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - return getattr(module, name) + if result.stderr: + for line in result.stderr.strip().split("\n"): + logger.error(line) - raise ValueError(f"No test function found in module {module_name}") + return result.returncode == 0 def run_tests() -> List[Tuple[str, bool]]: @@ -54,12 +50,11 @@ def run_tests() -> List[Tuple[str, bool]]: for module_name in TEST_MODULES: try: - test_func = load_test_function(module_name) logger.info(f"\n{'=' * 50}") logger.info(f"Running test: {module_name}") logger.info(f"{'-' * 50}") - success = test_func() + success = run_test_module(module_name) results.append((module_name, success)) status = "✅ PASSED" if success else "❌ FAILED" From 52e3088b31d659064e740388bd2f25df1c3b158f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:26:23 +0000 Subject: [PATCH 081/204] logging -> debug level Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 3a8b163f5..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,7 +10,7 @@ import subprocess from typing import List, Tuple -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) TEST_MODULES = [ From 641c09b0d2a5fb5c79b3b696f767f81d0b5283e4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:28:18 +0000 Subject: [PATCH 082/204] change table name in log Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py index 394c48b24..a200d97d3 100644 --- 
a/examples/experimental/tests/test_sea_metadata.py +++ b/examples/experimental/tests/test_sea_metadata.py @@ -71,7 +71,7 @@ def test_sea_metadata(): # Test columns for a specific table # Using a common table that should exist in most environments logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." ) cursor.columns( catalog_name=catalog, schema_name="default", table_name="customer" From ffded6ee2c50eb2efc1cdd2e580d51e396ce2cdd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:39:37 +0000 Subject: [PATCH 083/204] remove un-necessary changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 168 +++---- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 -------- .../experimental/tests/test_sea_metadata.py | 98 ---- .../experimental/tests/test_sea_session.py | 71 --- .../experimental/tests/test_sea_sync_query.py | 161 ------- tests/unit/test_sea_backend.py | 453 ++++-------------- tests/unit/test_sea_result_set.py | 200 -------- tests/unit/test_thrift_backend.py | 32 +- 9 files changed, 155 insertions(+), 1219 deletions(-) delete mode 100644 examples/experimental/tests/__init__.py delete mode 100644 examples/experimental/tests/test_sea_async_query.py delete mode 100644 examples/experimental/tests/test_sea_metadata.py delete mode 100644 examples/experimental/tests/test_sea_session.py delete mode 100644 examples/experimental/tests/test_sea_sync_query.py delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index b03f8ff64..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,120 +1,66 @@ -""" -Main script to run all SEA connector tests. - -This script imports and runs all the individual test modules and displays -a summary of test results with visual indicators. 
-""" import os import sys import logging -import importlib.util -from typing import Dict, Callable, List, Tuple +from databricks.sql.client import Connection -# Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -# Define test modules and their main test functions -TEST_MODULES = [ - "test_sea_session", - "test_sea_sync_query", - "test_sea_async_query", - "test_sea_metadata", -] - - -def load_test_function(module_name: str) -> Callable: - """Load a test function from a module.""" - module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" - ) - - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Get the main test function (assuming it starts with "test_") - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - # For sync and async query modules, we want the main function that runs both tests - if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec": - return getattr(module, name) - - # Fallback to the first test function found - for name in dir(module): - if name.startswith("test_") and callable(getattr(module, name)): - return getattr(module, name) - - raise ValueError(f"No test function found in module {module_name}") - - -def run_tests() -> List[Tuple[str, bool]]: - """Run all tests and return results.""" - results = [] - - for module_name in TEST_MODULES: - try: - test_func = load_test_function(module_name) - logger.info(f"\n{'=' * 50}") - logger.info(f"Running test: {module_name}") - logger.info(f"{'-' * 50}") - - success = test_func() - results.append((module_name, success)) - - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"Test {module_name}: {status}") - - except Exception as e: - logger.error(f"Error loading or running test {module_name}: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - results.append((module_name, False)) - - return results - - -def print_summary(results: List[Tuple[str, bool]]) -> None: - """Print a summary of test results.""" - logger.info(f"\n{'=' * 50}") - logger.info("TEST SUMMARY") - logger.info(f"{'-' * 50}") - - passed = sum(1 for _, success in results if success) - total = len(results) - - for module_name, success in results: - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"{status} - {module_name}") - - logger.info(f"{'-' * 50}") - logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") - logger.info(f"{'=' * 50}") - - -if __name__ == "__main__": - # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] - missing_vars = [var for var in required_vars if not os.environ.get(var)] - - if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) - logger.error("Please set these variables before running the tests.") + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) sys.exit(1) + + logger.info("SEA session test completed successfully") - # Run all tests - results = run_tests() - - # Print summary - print_summary(results) - - # Exit with appropriate status code - all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) +if __name__ == "__main__": + test_sea_session() diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py deleted file mode 100644 index a776377c3..000000000 --- a/examples/experimental/tests/test_sea_async_query.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Test for SEA asynchronous query execution functionality. -""" -import os -import sys -import logging -import time -from databricks.sql.client import Connection -from databricks.sql.backend.types import CommandState - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_async_query_with_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch enabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_without_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch disabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch disabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_exec(): - """ - Run both asynchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info( - f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info( - f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_async_query_exec() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py deleted file mode 100644 index c715e5984..000000000 --- a/examples/experimental/tests/test_sea_metadata.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Test for SEA metadata functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_metadata(): - """ - Test metadata operations using the SEA backend. - - This function connects to a Databricks SQL endpoint using the SEA backend, - and executes metadata operations like catalogs(), schemas(), tables(), and columns(). - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - if not catalog: - logger.error( - "DATABRICKS_CATALOG environment variable is required for metadata tests." 
- ) - return False - - try: - # Create connection - logger.info("Creating connection for metadata operations") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Test catalogs - cursor = connection.cursor() - logger.info("Fetching catalogs...") - cursor.catalogs() - logger.info("Successfully fetched catalogs") - - # Test schemas - logger.info(f"Fetching schemas for catalog '{catalog}'...") - cursor.schemas(catalog_name=catalog) - logger.info("Successfully fetched schemas") - - # Test tables - logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") - cursor.tables(catalog_name=catalog, schema_name="default") - logger.info("Successfully fetched tables") - - # Test columns for a specific table - # Using a common table that should exist in most environments - logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'information_schema'..." - ) - cursor.columns( - catalog_name=catalog, schema_name="default", table_name="information_schema" - ) - logger.info("Successfully fetched columns") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error during SEA metadata test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_metadata() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py deleted file mode 100644 index 516c1bbb8..000000000 --- a/examples/experimental/tests/test_sea_session.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Test for SEA session management functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. - - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"Backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_session() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py deleted file mode 100644 index 07be8aafc..000000000 --- a/examples/experimental/tests/test_sea_sync_query.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Test for SEA synchronous query execution functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_sync_query_with_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_without_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch disabled. 
- - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_exec(): - """ - Run both synchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info( - f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info( - f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..bc2688a68 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,20 +1,11 @@ -""" -Tests for the SEA (Statement Execution API) backend implementation. - -This module contains tests for the SeaDatabricksClient class, which implements -the Databricks SQL connector's SEA backend functionality. 
-""" - -import json import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error class TestSeaBackend: @@ -50,23 +41,6 @@ def sea_client(self, mock_http_client): return client - @pytest.fixture - def sea_session_id(self): - """Create a SEA session ID.""" - return SessionId.from_sea_session_id("test-session-123") - - @pytest.fixture - def sea_command_id(self): - """Create a SEA command ID.""" - return CommandId.from_sea_statement_id("test-statement-123") - - @pytest.fixture - def mock_cursor(self): - """Create a mock cursor.""" - cursor = Mock() - cursor.active_command_id = None - return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -201,348 +175,109 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command synchronously.""" - # Set up mock responses - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "schema": [ - { - "name": "col1", - "type_name": "STRING", - "type_text": "string", - "nullable": True, - } - ], - "total_row_count": 1, - "total_byte_count": 100, - }, - "result": {"data": [["value1"]]}, - } - mock_http_client._make_request.return_value = execute_response - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" - - def test_execute_command_async( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command asynchronously.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-456", - "status": {"state": "PENDING"}, - } - 
mock_http_client._make_request.return_value = execute_response - - # Call the method - result = sea_client.execute_command( - operation="SELECT 1", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=True, # Async mode - enforce_embedded_schema_correctness=False, + def test_session_configuration_helpers(self): + """Test the session configuration helper methods.""" + # Test getting default value for a supported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "ANSI_MODE" ) + assert default_value == "true" - # Verify the result is None for async operation - assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") - assert isinstance(mock_cursor.active_command_id, CommandId) - assert mock_cursor.active_command_id.guid == "test-statement-456" - - def test_execute_command_with_polling( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling - initial_response = { - "statement_id": "test-statement-789", - "status": {"state": "RUNNING"}, - } - poll_response = { - "statement_id": "test-statement-789", - "status": {"state": "SUCCEEDED"}, - "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, - "result": {"data": []}, - } - - # Configure mock to return different responses on subsequent calls - mock_http_client._make_request.side_effect = [initial_response, poll_response] - - # Mock the get_execution_result method - with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" - ) as mock_get_result: - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method - result = sea_client.execute_command( - operation="SELECT * FROM large_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - # Verify the result - assert result == "mock_result_set" - - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response - execute_response = { - "statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, + # Test getting default value for an unsupported parameter + default_value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert default_value is None + + # Test getting the list of allowed configurations + allowed_configs = 
SeaDatabricksClient.get_allowed_session_configurations() + + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", } - mock_http_client._make_request.return_value = execute_response + assert set(allowed_configs) == expected_keys - # Create parameter mock - param = MagicMock() - param.name = "param1" - param.value = "value1" - param.type = "STRING" + def test_unimplemented_methods(self, sea_client): + """Test that unimplemented methods raise NotImplementedError.""" + # Create dummy parameters for testing + session_id = SessionId.from_sea_session_id("test-session") + command_id = MagicMock() + cursor = MagicMock() - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + # Test execute_command + with pytest.raises(NotImplementedError) as excinfo: sea_client.execute_command( - operation="SELECT * FROM table WHERE col = :param1", - session_id=sea_session_id, + operation="SELECT 1", + session_id=session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=mock_cursor, + cursor=cursor, use_cloud_fetch=False, - parameters=[param], + parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert "parameters" in kwargs["data"] - assert len(kwargs["data"]["parameters"]) == 1 - assert kwargs["data"]["parameters"][0]["name"] == "param1" - assert kwargs["data"]["parameters"][0]["value"] == "value1" - assert kwargs["data"]["parameters"][0]["type"] == "STRING" - - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution - error_response = { - "statement_id": "test-statement-123", - "status": { - "state": "FAILED", - "error": { - "message": "Syntax error in SQL", - "error_code": "SYNTAX_ERROR", - }, - }, - } - - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays - with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == 
sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method - sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response - mock_http_client._make_request.return_value = { - "statement_id": "test-statement-123", - "status": {"state": "RUNNING"}, - } - - # Call the method - state = sea_client.get_query_state(sea_command_id) - - # Verify the result - assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + assert "execute_command is not yet implemented" in str(excinfo.value) + + # Test cancel_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.cancel_command(command_id) + assert "cancel_command is not yet implemented" in str(excinfo.value) + + # Test close_command + with pytest.raises(NotImplementedError) as excinfo: + sea_client.close_command(command_id) + assert "close_command is not yet implemented" in str(excinfo.value) + + # Test get_query_state + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_query_state(command_id) + assert "get_query_state is not yet implemented" in str(excinfo.value) + + # Test get_execution_result + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_execution_result(command_id, cursor) + assert "get_execution_result is not yet implemented" in str(excinfo.value) + + # Test metadata operations + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(session_id, 100, 1000, cursor) + assert "get_catalogs is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(session_id, 100, 1000, cursor) + assert "get_schemas is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(session_id, 100, 1000, cursor) + assert "get_tables is not yet implemented" in str(excinfo.value) + + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(session_id, 100, 1000, cursor) + assert "get_columns is not yet implemented" in str(excinfo.value) + + def test_max_download_threads_property(self, sea_client): + """Test the max_download_threads property.""" + assert sea_client.max_download_threads == 10 + + # Create a client with a custom value + custom_client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=20, ) - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response - sea_response = { - 
"statement_id": "test-statement-123", - "status": {"state": "SUCCEEDED"}, - "manifest": { - "format": "JSON_ARRAY", - "schema": { - "column_count": 1, - "columns": [ - { - "name": "test_value", - "type_text": "INT", - "type_name": "INT", - "position": 0, - } - ], - }, - "total_chunk_count": 1, - "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], - "total_row_count": 1, - "truncated": False, - }, - "result": { - "chunk_index": 0, - "row_offset": 0, - "row_count": 1, - "data_array": [["1"]], - }, - } - mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation - result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) - - # Verify basic properties of the result - assert result.statement_id == "test-statement-123" - assert result.status == CommandState.SUCCEEDED - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) + # Verify the custom value is returned + assert custom_client.max_download_threads == 20 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index b691872af..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. -""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - 
result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises 
NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 88adcd3e9..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,13 +619,6 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), @@ -927,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -957,12 +948,6 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) @@ -977,7 +962,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,12 +982,6 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req tcli_service_instance.GetOperationStatus.return_value = ( ttypes.TGetOperationStatusResp( @@ -1694,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2256,8 +2233,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - 
return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class, mock_result_set From 227f6b36bd65cc8a7c903316334a18a8a8e249b1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 13:41:29 +0000 Subject: [PATCH 084/204] remove un-necessary backend cahnges Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 481 ++----------------- src/databricks/sql/backend/thrift_backend.py | 11 +- src/databricks/sql/result_set.py | 73 --- 3 files changed, 42 insertions(+), 523 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b1ad7cf76..97d25a058 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,44 +1,23 @@ import logging -import time import re -from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set - -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) +from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet - from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import ( - SessionId, - CommandId, - CommandState, - BackendType, - ExecuteResponse, -) -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.types import SSLOptions -from databricks.sql.backend.sea.models.base import ( - ResultData, - ExternalLink, - ResultManifest, +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, ) +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( - ExecuteStatementRequest, - GetStatementRequest, - CancelStatementRequest, - CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, - StatementParameter, - ExecuteStatementResponse, - GetStatementResponse, CreateSessionResponse, ) @@ -76,9 +55,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. 
""" # SEA API paths @@ -88,8 +64,6 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" - CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" - CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -122,7 +96,6 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) - self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -279,19 +252,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) - @staticmethod - def is_session_configuration_parameter_supported(name: str) -> bool: - """ - Check if a session configuration parameter is supported. - - Args: - name: The name of the session configuration parameter - - Returns: - True if the parameter is supported, False otherwise - """ - return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP - @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -302,182 +262,8 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - def _get_schema_bytes(self, sea_response) -> Optional[bytes]: - """ - Extract schema bytes from the SEA response. - - For ARROW format, we need to get the schema bytes from the first chunk. - If the first chunk is not available, we need to get it from the server. - - Args: - sea_response: The response from the SEA API - - Returns: - bytes: The schema bytes or None if not available - """ - import requests - import lz4.frame - - # Check if we have the first chunk in the response - result_data = sea_response.get("result", {}) - external_links = result_data.get("external_links", []) - - if not external_links: - return None - - # Find the first chunk (chunk_index = 0) - first_chunk = None - for link in external_links: - if link.get("chunk_index") == 0: - first_chunk = link - break - - if not first_chunk: - # Try to fetch the first chunk from the server - statement_id = sea_response.get("statement_id") - if not statement_id: - return None - - chunks_response = self.get_chunk_links(statement_id, 0) - if not chunks_response.external_links: - return None - - first_chunk = chunks_response.external_links[0].__dict__ - - # Download the first chunk to get the schema bytes - external_link = first_chunk.get("external_link") - http_headers = first_chunk.get("http_headers", {}) - - if not external_link: - return None - - # Use requests to download the first chunk - http_response = requests.get( - external_link, - headers=http_headers, - verify=self.ssl_options.tls_verify, - ) - - if http_response.status_code != 200: - raise Error(f"Failed to download schema bytes: {http_response.text}") - - # Extract schema bytes from the Arrow file - # The schema is at the beginning of the file - data = http_response.content - if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": - data = lz4.frame.decompress(data) - - # Return the schema bytes - return data - - def _results_message_to_execute_response(self, sea_response, command_id): - """ - Convert a SEA response to an ExecuteResponse and extract result data. 
- - Args: - sea_response: The response from the SEA API - command_id: The command ID - - Returns: - tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, - result data object, and manifest object - """ - # Extract status - status_data = sea_response.get("status", {}) - state = CommandState.from_sea_state(status_data.get("state", "")) - - # Extract description from manifest - description = None - manifest_data = sea_response.get("manifest", {}) - schema_data = manifest_data.get("schema", {}) - columns_data = schema_data.get("columns", []) - - if columns_data: - columns = [] - for col_data in columns_data: - if not isinstance(col_data, dict): - continue - - # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) - columns.append( - ( - col_data.get("name", ""), # name - col_data.get("type_name", ""), # type_code - None, # display_size (not provided by SEA) - None, # internal_size (not provided by SEA) - col_data.get("precision"), # precision - col_data.get("scale"), # scale - col_data.get("nullable", True), # null_ok - ) - ) - description = columns if columns else None - - # Extract schema bytes for Arrow format - schema_bytes = None - format = manifest_data.get("format") - if format == "ARROW_STREAM": - # For ARROW format, we need to get the schema bytes - schema_bytes = self._get_schema_bytes(sea_response) - - # Check for compression - lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" - - # Initialize result_data_obj and manifest_obj - result_data_obj = None - manifest_obj = None - - result_data = sea_response.get("result", {}) - if result_data: - # Convert external links - external_links = None - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers", {}), - ) - ) - - # Create the result data object - result_data_obj = ResultData( - data=result_data.get("data_array"), external_links=external_links - ) - - # Create the manifest object - manifest_obj = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - execute_response = ExecuteResponse( - command_id=command_id, - status=state, - description=description, - has_been_closed_server_side=False, - lz4_compressed=lz4_compressed, - is_staging_operation=False, - arrow_schema_bytes=schema_bytes, - result_format=manifest_data.get("format"), - ) - - return execute_response, result_data_obj, manifest_obj + # == Not Implemented Operations == + # These methods will be implemented in future iterations def execute_command( self, @@ -488,230 +274,41 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List, 
+ parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: - """ - Execute a SQL command using the SEA backend. - - Args: - operation: SQL command to execute - session_id: Session identifier - max_rows: Maximum number of rows to fetch - max_bytes: Maximum number of bytes to fetch - lz4_compression: Whether to use LZ4 compression - cursor: Cursor executing the command - use_cloud_fetch: Whether to use cloud fetch - parameters: SQL parameters - async_op: Whether to execute asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - ResultSet: A SeaResultSet instance for the executed command - """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") - - sea_session_id = session_id.to_sea_session_id() - - # Convert parameters to StatementParameter objects - sea_parameters = [] - if parameters: - for param in parameters: - sea_parameters.append( - StatementParameter( - name=param.name, - value=param.value, - type=param.type if hasattr(param, "type") else None, - ) - ) - - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else None - - request = ExecuteStatementRequest( - warehouse_id=self.warehouse_id, - session_id=sea_session_id, - statement=operation, - disposition=disposition, - format=format, - wait_timeout="0s" if async_op else "10s", - on_wait_timeout="CONTINUE", - row_limit=max_rows, - parameters=sea_parameters if sea_parameters else None, - result_compression=result_compression, + ): + """Not implemented yet.""" + raise NotImplementedError( + "execute_command is not yet implemented for SEA backend" ) - response_data = self.http_client._make_request( - method="POST", path=self.STATEMENT_PATH, data=request.to_dict() - ) - response = ExecuteStatementResponse.from_dict(response_data) - statement_id = response.statement_id - if not statement_id: - raise ServerOperationError( - "Failed to execute command: No statement ID returned", - { - "operation-id": None, - "diagnostic-info": None, - }, - ) - - command_id = CommandId.from_sea_statement_id(statement_id) - - # Store the command ID in the cursor - cursor.active_command_id = command_id - - # If async operation, return None and let the client poll for results - if async_op: - return None - - # For synchronous operation, wait for the statement to complete - # Poll until the statement is done - status = response.status - state = status.state - - # Keep polling until we reach a terminal state - while state in [CommandState.PENDING, CommandState.RUNNING]: - time.sleep(0.5) # add a small delay to avoid excessive API calls - state = self.get_query_state(command_id) - - if state != CommandState.SUCCEEDED: - raise ServerOperationError( - f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", - { - "operation-id": command_id.to_sea_statement_id(), - "diagnostic-info": None, - }, - ) - - return self.get_execution_result(command_id, cursor) - def cancel_command(self, command_id: CommandId) -> None: - """ - Cancel a running command. 
- - Args: - command_id: Command identifier to cancel - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CancelStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="POST", - path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "cancel_command is not yet implemented for SEA backend" ) def close_command(self, command_id: CommandId) -> None: - """ - Close a command and release resources. - - Args: - command_id: Command identifier to close - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = CloseStatementRequest(statement_id=sea_statement_id) - self.http_client._make_request( - method="DELETE", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "close_command is not yet implemented for SEA backend" ) def get_query_state(self, command_id: CommandId) -> CommandState: - """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - request = GetStatementRequest(statement_id=sea_statement_id) - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), + """Not implemented yet.""" + raise NotImplementedError( + "get_query_state is not yet implemented for SEA backend" ) - # Parse the response - response = GetStatementResponse.from_dict(response_data) - return response.status.state - def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ) -> "ResultSet": - """ - Get the result of a command execution. 
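Each of the command operations removed above begins with the same SEA command-id validation before calling the statement endpoints. A standalone sketch of that guard, using only names defined in this patch:

from databricks.sql.backend.types import BackendType, CommandId


def require_sea_statement_id(command_id: CommandId) -> str:
    # Reject command IDs that did not originate from the SEA backend.
    if command_id.backend_type != BackendType.SEA:
        raise ValueError("Not a valid SEA command ID")
    return command_id.to_sea_statement_id()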
- - Args: - command_id: Command identifier - cursor: Cursor executing the command - - Returns: - ResultSet: A SeaResultSet instance with the execution results - - Raises: - ValueError: If the command ID is invalid - """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self.http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - - # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet - - # Convert the response to an ExecuteResponse and extract result data - ( - execute_response, - result_data, - manifest, - ) = self._results_message_to_execute_response(response_data, command_id) - - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - result_data=result_data, - manifest=manifest, + ): + """Not implemented yet.""" + raise NotImplementedError( + "get_execution_result is not yet implemented for SEA backend" ) # == Metadata Operations == @@ -722,9 +319,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -734,9 +331,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -748,9 +345,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -762,6 +359,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + ): + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d28a2c6fd..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -15,6 +15,7 @@ CommandId, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id try: @@ -841,8 +842,6 @@ def get_execution_result( status = 
self.get_query_state(command_id) - status = self.get_query_state(command_id) - execute_response = ExecuteResponse( command_id=command_id, status=status, @@ -895,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1189,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 97b10cbbe..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -438,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. 
- """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") From 68657a3ba20080dde478b3e9d4b0940bdf4ca299 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 14:52:28 +0000 Subject: [PATCH 085/204] remove un-needed GetChunksResponse Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 1 - .../sql/backend/sea/models/responses.py | 35 ------------------- 2 files changed, 36 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b1ad7cf76..6d627162d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor from databricks.sql.result_set import ResultSet - from databricks.sql.backend.sea.models.responses import GetChunksResponse from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d684a9c67..1f73df409 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -196,38 +196,3 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) - - -@dataclass -class GetChunksResponse: - """Response from getting chunks for a statement.""" - - statement_id: str - external_links: List[ExternalLink] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": - """Create a GetChunksResponse from a dictionary.""" - external_links = [] - if "external_links" in data: - for link_data in data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - return cls( - statement_id=data.get("statement_id", ""), - external_links=external_links, - ) From 3940eecd0671deee86ef9b81a1853fcedaf31bb1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 14:53:15 +0000 Subject: [PATCH 086/204] remove un-needed GetChunksResponse only relevant in Fetch phase Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index d71262d1d..51f0d4452 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -196,38 +196,3 @@ class CreateSessionResponse: def from_dict(cls, data: Dict[str, Any]) -> 
"CreateSessionResponse": """Create a CreateSessionResponse from a dictionary.""" return cls(session_id=data.get("session_id", "")) - - -@dataclass -class GetChunksResponse: - """Response from getting chunks for a statement.""" - - statement_id: str - external_links: List[ExternalLink] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse": - """Create a GetChunksResponse from a dictionary.""" - external_links = [] - if "external_links" in data: - for link_data in data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - return cls( - statement_id=data.get("statement_id", ""), - external_links=external_links, - ) From 37813ba6d1fe06d7f9f10d510a059b88dc552496 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:00:35 +0000 Subject: [PATCH 087/204] reduce code duplication in response parsing Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 219 +++++++----------- 1 file changed, 78 insertions(+), 141 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1f73df409..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,8 +4,8 @@ These models define the structures used in SEA API responses. 
""" -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field +from typing import Dict, Any +from dataclasses import dataclass from databricks.sql.backend.types import CommandState from databricks.sql.backend.sea.models.base import ( @@ -14,91 +14,92 @@ ResultData, ServiceError, ExternalLink, - ColumnInfo, ) +def _parse_status(data: Dict[str, Any]) -> StatementStatus: + """Parse status from response data.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + return StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: + """Parse manifest from response data.""" + + manifest_data = data.get("manifest", {}) + return ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + +def _parse_result(data: Dict[str, Any]) -> ResultData: + """Parse result data from response data.""" + result_data = data.get("result", {}) + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get("next_chunk_internal_link"), + http_headers=link_data.get("http_headers"), + ) + ) + + return ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + + @dataclass class ExecuteStatementResponse: """Response from executing a SQL statement.""" statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - 
total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -108,81 +109,17 @@ class GetStatementResponse: statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - 
data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) From 267c9f44e55778af748749336c26bb06ce0ab33c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:01:29 +0000 Subject: [PATCH 088/204] reduce code duplication Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/responses.py | 221 +++++++----------- 1 file changed, 79 insertions(+), 142 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 51f0d4452..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,8 +4,8 @@ These models define the structures used in SEA API responses. """ -from typing import Dict, List, Any, Optional, Union -from dataclasses import dataclass, field +from typing import Dict, Any +from dataclasses import dataclass from databricks.sql.backend.types import CommandState from databricks.sql.backend.sea.models.base import ( @@ -14,91 +14,92 @@ ResultData, ServiceError, ExternalLink, - ColumnInfo, ) +def _parse_status(data: Dict[str, Any]) -> StatementStatus: + """Parse status from response data.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + return StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: + """Parse manifest from response data.""" + + manifest_data = data.get("manifest", {}) + return ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + +def _parse_result(data: Dict[str, Any]) -> ResultData: + """Parse result data from response data.""" + result_data = data.get("result", {}) + external_links = None + + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get("next_chunk_internal_link"), + http_headers=link_data.get("http_headers"), + ) + ) + + return ResultData( + data=result_data.get("data_array"), + external_links=external_links, + ) + + @dataclass class ExecuteStatementResponse: """Response from executing a SQL statement.""" statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] 
= None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -108,87 +109,23 @@ class GetStatementResponse: statement_id: str status: StatementStatus - manifest: Optional[ResultManifest] = None - result: Optional[ResultData] = None + manifest: ResultManifest + result: ResultData @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" - status_data = data.get("status", {}) - error = None - if "error" in status_data: - error_data = status_data["error"] - error = ServiceError( - message=error_data.get("message", ""), - error_code=error_data.get("error_code"), - ) - - state = CommandState.from_sea_state(status_data.get("state", "")) - if state is None: - raise ValueError(f"Invalid state: {status_data.get('state', '')}") - - status = StatementStatus( - state=state, - error=error, - sql_state=status_data.get("sql_state"), - ) - - # Parse manifest - manifest = None - if "manifest" in data: - manifest_data = data["manifest"] - manifest = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - 
total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) - - # Parse result data - result = None - if "result" in data: - result_data = data["result"] - external_links = None - - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), - next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers"), - ) - ) - - result = ResultData( - data=result_data.get("data_array"), - external_links=external_links, - ) - return cls( statement_id=data.get("statement_id", ""), - status=status, - manifest=manifest, - result=result, + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str From 296711946a5dd735a655961984641ed2a19d0f2a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:03:07 +0000 Subject: [PATCH 089/204] more clear docstrings Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/requests.py | 10 +++++----- src/databricks/sql/backend/sea/models/responses.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index d9483e51a..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Parameter for a SQL statement.""" + """Representation of a parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Request to execute a SQL statement.""" + """Representation of a request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Request to get information about a statement.""" + """Representation of a request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Request to cancel a statement.""" + """Representation of a request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Request to close a statement.""" + """Representation of a request to close a statement.""" statement_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index c16f19da3..a8cf0c998 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -85,7 +85,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: @dataclass class ExecuteStatementResponse: - """Response 
from executing a SQL statement.""" + """Representation of the response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -105,7 +105,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Response from getting information about a statement.""" + """Representation of the response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -125,7 +125,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str From 47fd60d2b20fcaf1f39300a88224899edb2c0a58 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 11 Jun 2025 15:25:24 +0000 Subject: [PATCH 090/204] introduce strongly typed ChunkInfo Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/base.py | 12 +++++++++++- .../sql/backend/sea/models/responses.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 6175b4ca0..f63edba72 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,6 +42,16 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None +@dataclass +class ChunkInfo: + """Information about a chunk in the result set.""" + + chunk_index: int + byte_count: int + row_offset: int + row_count: int + + @dataclass class ResultData: """Result data from a statement execution.""" @@ -73,5 +83,5 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[Dict[str, Any]]] = None + chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index a8cf0c998..7388af193 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,6 +14,7 @@ ResultData, ServiceError, ExternalLink, + ChunkInfo, ) @@ -43,6 +44,18 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) + chunks = None + if "chunks" in manifest_data: + chunks = [ + ChunkInfo( + chunk_index=chunk.get("chunk_index", 0), + byte_count=chunk.get("byte_count", 0), + row_offset=chunk.get("row_offset", 0), + row_count=chunk.get("row_count", 0), + ) + for chunk in manifest_data.get("chunks", []) + ] + return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -50,8 +63,9 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), + chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) From 982fdf2df8480d6ddd8c93b5f8839e4cf5ccce2e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 03:08:31 +0000 Subject: [PATCH 091/204] remove is_volume_operation from response Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/responses.py 
| 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 7388af193..42dcd356a 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -65,7 +65,6 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: truncated=manifest_data.get("truncated", False), chunks=chunks, result_compression=manifest_data.get("result_compression"), - is_volume_operation=manifest_data.get("is_volume_operation"), ) From 9e14d48fdb03500ad13e098cd963d7a04dadd9a0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:06:47 +0000 Subject: [PATCH 092/204] add is_volume_op and more ResultData fields Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/models/base.py | 8 ++++++++ src/databricks/sql/backend/sea/models/responses.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index f63edba72..b12c26eb0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -58,6 +58,13 @@ class ResultData: data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + attachment: Optional[bytes] = None @dataclass @@ -85,3 +92,4 @@ class ResultManifest: truncated: bool = False chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None + is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 42dcd356a..0baf27ab2 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -65,6 +65,7 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: truncated=manifest_data.get("truncated", False), chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -93,6 +94,13 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, + byte_count=result_data.get("byte_count"), + chunk_index=result_data.get("chunk_index"), + next_chunk_index=result_data.get("next_chunk_index"), + next_chunk_internal_link=result_data.get("next_chunk_internal_link"), + row_count=result_data.get("row_count"), + row_offset=result_data.get("row_offset"), + attachment=result_data.get("attachment"), ) From 05ee4e78fe72c200e90842d5d916546b08a1a51c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:11:25 +0000 Subject: [PATCH 093/204] add test scripts Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 98 +++++++++ .../experimental/tests/test_sea_session.py | 71 +++++++ .../experimental/tests/test_sea_sync_query.py | 161 +++++++++++++++ 5 files changed, 521 insertions(+) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 
examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..a776377c3 --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,191 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch enabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. 
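A condensed version of the asynchronous pattern exercised by these tests, with connection setup omitted; the cursor methods are the ones used in this patch:

import time


def run_async_query(cursor, query="SELECT 1 as test_value"):
    cursor.execute_async(query)
    while cursor.is_query_pending():
        time.sleep(1)  # wait until the query is no longer pending
    return cursor.get_async_execution_result()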
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query asynchronously + cursor = connection.cursor() + logger.info( + "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute_async("SELECT 1 as test_value") + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + logger.info( + "Successfully retrieved asynchronous query results with cloud fetch disabled" + ) + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..a200d97d3 --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,98 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. + + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). 
+ """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." + ) + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="customer" + ) + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..516c1bbb8 --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,71 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..07be8aafc --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,161 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query with cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch enabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query + cursor = connection.cursor() + logger.info( + "Executing synchronous query without cloud fetch: SELECT 1 as test_value" + ) + cursor.execute("SELECT 1 as test_value") + logger.info("Query executed successfully with cloud fetch disabled") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. 
+ """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) From 2952d8dc2de6adf25ac1c9dd358fc7f5bfc6f495 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:15:01 +0000 Subject: [PATCH 094/204] Revert "Merge branch 'sea-migration' into exec-models-sea" This reverts commit 8bd12d829ea13abf8fc1507fff8cb21751001c67, reversing changes made to 030edf8df3db487b7af8d910ee51240d1339229e. --- .../sql/backend/sea/models/requests.py | 4 +- .../sql/backend/sea/models/responses.py | 2 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 130 ++++++++---------- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 86 +++++------- src/databricks/sql/utils.py | 6 +- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +++--- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_thrift_backend.py | 106 +++++--------- 11 files changed, 159 insertions(+), 237 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 4c5071dba..8524275d4 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Representation of a request to create a new session.""" + """Request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Representation of a request to delete a session.""" + """Request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..4dcd4af02 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..48e9a115f 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,21 +3,24 
@@ import logging import math import time +import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor +from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, + BackendType, + guid_to_hex_id, ExecuteResponse, ) from databricks.sql.backend.utils import guid_to_hex_id - try: import pyarrow except ImportError: @@ -757,13 +760,11 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - - is_direct_results = ( + has_more_rows = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,25 +780,43 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + if direct_results and direct_results.resultSet: + assert direct_results.resultSet.results.startRowOffset == 0 + assert direct_results.resultSetMetadata + + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) + else: + arrow_queue_opt = None + command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - execute_response = ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) - return execute_response, is_direct_results - def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -822,6 +841,9 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -836,21 +858,25 @@ def get_execution_result( else: schema_bytes = None - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - is_direct_results = resp.hasMoreRows - - status = self.get_query_state(command_id) + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=schema_bytes, + 
max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, + has_more_rows=has_more_rows, + results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -860,10 +886,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=resp.results, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -976,14 +999,10 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -991,10 +1010,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1016,14 +1032,10 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1031,10 +1043,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1060,14 +1069,10 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1075,10 +1080,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1108,14 +1110,10 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = 
self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1123,10 +1121,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1156,14 +1151,10 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1171,10 +1162,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,9 +423,11 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None - result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cf6940bb2..e177d495f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - is_direct_results: bool = False, + has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
- Parameters: - :param connection: The parent connection - :param backend: The backend client - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - :param command_id: The command ID - :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server - :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue - :param description: column description of the results - :param is_staging_operation: Whether the command is a staging operation + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,47 +157,25 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - t_row_set=None, - max_download_threads: int = 10, - ssl_options=None, - is_direct_results: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Parameters: - :param connection: The parent connection - :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access - :param buffer_size_bytes: Buffer size for fetching results - :param arraysize: Default number of rows to fetch - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - :param t_row_set: The TRowSet containing result data (if available) - :param max_download_threads: Maximum number of download threads for cloud fetch - :param ssl_options: SSL options for cloud fetch - :param is_direct_results: Whether there are more rows to fetch + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, @@ -207,8 +185,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - is_direct_results=is_direct_results, - results_queue=results_queue, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -218,7 +196,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, is_direct_results = self.backend.fetch_results( + results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -229,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -313,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -338,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() 
partial_results = self.results.next_n_rows(n_remaining_rows) @@ -353,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -379,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d7b1b74b4..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
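The revert above restores the has_more_rows pagination in result_set.py: the fetchmany_*/fetchall_* loops keep refilling the buffer from the backend until the server reports no more rows. A minimal, self-contained sketch of that pattern, assuming illustrative stand-in names (FakeBackend, Page) rather than the driver's real classes:

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Page:
    rows: List[tuple]


class FakeBackend:
    """Illustrative stand-in for a backend whose fetch_results returns (page, has_more_rows)."""

    def __init__(self, pages: List[Page]):
        self._pages = list(pages)

    def fetch_results(self) -> Tuple[Page, bool]:
        page = self._pages.pop(0)
        # has_more_rows stays True until the last buffered page has been handed out
        return page, bool(self._pages)


def fetchall(backend: FakeBackend) -> List[tuple]:
    """Drain every page, mirroring the has_more_rows loop in the fetchall_* methods above."""
    results: List[tuple] = []
    has_more_rows = True
    while has_more_rows:
        page, has_more_rows = backend.fetch_results()
        results.extend(page.rows)
    return results


if __name__ == "__main__":
    backend = FakeBackend([Page([(1,), (2,)]), Page([(3,)])])
    assert fetchall(backend) == [(1,), (2,), (3,)]
    print("fetched 3 rows across 2 pages")

The real ThriftResultSet additionally gates the loop on has_been_closed_server_side, so a command already closed via direct results is never re-fetched; the sketch omits that check for brevity.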
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2054d01d1..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - is_direct_results=True, + has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,7 +104,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None - mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -185,7 +184,6 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -212,7 +210,6 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -257,10 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) - - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -478,6 +472,7 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq + mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,30 +40,25 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - - # Create a mock backend that will return the queue when _fill_results_buffer is called - mock_thrift_backend = Mock(spec=ThriftDatabricksClient) - mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) - - num_cols = len(initial_results[0]) if initial_results else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - description=description, - lz4_compressed=True, + has_more_rows=False, + description=Mock(), + lz4_compressed=Mock(), + results_queue=arrow_queue, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, - t_row_set=None, + thrift_client=None, ) + num_cols = len(initial_results[0]) if initial_results else 0 + 
rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod @@ -90,19 +85,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - description=description, - lz4_compressed=True, + has_more_rows=True, + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], + lz4_compressed=Mock(), + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e4a9e5cdd..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - is_direct_results=False, + has_more_rows=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..8274190fe 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -623,10 +623,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,10 +832,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value @@ -882,10 +878,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) + self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -951,14 +947,8 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -983,14 
+973,8 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1004,10 +988,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1019,7 +1003,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, ), closeOperation=Mock(), @@ -1035,12 +1019,11 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - ( - execute_response, - has_more_rows_result, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) - self.assertEqual(is_direct_results, has_more_rows_result) + self.assertEqual(has_more_rows, execute_response.has_more_rows) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1049,10 +1032,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1065,7 +1048,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1098,7 +1081,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(is_direct_results, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1153,10 +1136,9 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): 
self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1169,15 +1151,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1189,10 +1170,9 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1205,13 +1185,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,10 +1201,9 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1238,8 +1216,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1251,7 +1228,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1264,10 +1241,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - 
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1280,8 +1256,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1295,7 +1270,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1310,10 +1285,9 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1326,8 +1300,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1341,7 +1314,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -2230,23 +2203,14 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class, mock_result_set + self, mock_handle_execute_response, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value - # Set up the mock to return a tuple with two values - mock_execute_response = Mock() - mock_arrow_schema = Mock() - mock_handle_execute_response.return_value = ( - mock_execute_response, - mock_arrow_schema, - ) - # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From cbace3f52c025d2b414c4169555f9daeaa27581d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:20:12 +0000 Subject: [PATCH 095/204] Revert "Merge branch 'exec-models-sea' into exec-phase-sea" This 
reverts commit be1997e0d6b6cf0f5499db2381971ec3a015a2f7, reversing changes made to 37813ba6d1fe06d7f9f10d510a059b88dc552496. --- examples/experimental/sea_connector_test.py | 68 +-- src/databricks/sql/backend/sea/backend.py | 480 ++++++++++++++++-- src/databricks/sql/backend/sea/models/base.py | 20 +- .../sql/backend/sea/models/requests.py | 14 +- .../sql/backend/sea/models/responses.py | 29 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 137 +++-- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 159 ++++-- src/databricks/sql/utils.py | 6 +- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_sea_backend.py | 453 +++++++++++++---- tests/unit/test_sea_result_set.py | 200 ++++++++ tests/unit/test_thrift_backend.py | 138 +++-- 16 files changed, 1300 insertions(+), 466 deletions(-) create mode 100644 tests/unit/test_sea_result_set.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 2553a2b20..0db326894 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,7 +10,8 @@ import subprocess from typing import List, Tuple -logging.basicConfig(level=logging.DEBUG) +# Configure logging +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) TEST_MODULES = [ @@ -87,48 +88,29 @@ def print_summary(results: List[Tuple[str, bool]]) -> None: logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") logger.info(f"{'=' * 50}") - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) - - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + +if __name__ == "__main__": + # Check if required environment variables are set + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) + logger.error("Please set these variables before running the tests.") sys.exit(1) - - logger.info("SEA session test completed successfully") -if __name__ == "__main__": - test_sea_session() + # Run all 
tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..6d627162d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,43 @@ import logging +import time import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError -from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, ) -from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions +from databricks.sql.backend.sea.models.base import ( + ResultData, + ExternalLink, + ResultManifest, +) from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +75,9 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -64,6 +87,8 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" + CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -96,6 +121,7 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) + self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -252,6 +278,19 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) + @staticmethod + def is_session_configuration_parameter_supported(name: str) -> bool: + """ + Check if a session configuration parameter is supported. 
+ + Args: + name: The name of the session configuration parameter + + Returns: + True if the parameter is supported, False otherwise + """ + return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP + @staticmethod def get_allowed_session_configurations() -> List[str]: """ @@ -262,8 +301,182 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - # == Not Implemented Operations == - # These methods will be implemented in future iterations + def _get_schema_bytes(self, sea_response) -> Optional[bytes]: + """ + Extract schema bytes from the SEA response. + + For ARROW format, we need to get the schema bytes from the first chunk. + If the first chunk is not available, we need to get it from the server. + + Args: + sea_response: The response from the SEA API + + Returns: + bytes: The schema bytes or None if not available + """ + import requests + import lz4.frame + + # Check if we have the first chunk in the response + result_data = sea_response.get("result", {}) + external_links = result_data.get("external_links", []) + + if not external_links: + return None + + # Find the first chunk (chunk_index = 0) + first_chunk = None + for link in external_links: + if link.get("chunk_index") == 0: + first_chunk = link + break + + if not first_chunk: + # Try to fetch the first chunk from the server + statement_id = sea_response.get("statement_id") + if not statement_id: + return None + + chunks_response = self.get_chunk_links(statement_id, 0) + if not chunks_response.external_links: + return None + + first_chunk = chunks_response.external_links[0].__dict__ + + # Download the first chunk to get the schema bytes + external_link = first_chunk.get("external_link") + http_headers = first_chunk.get("http_headers", {}) + + if not external_link: + return None + + # Use requests to download the first chunk + http_response = requests.get( + external_link, + headers=http_headers, + verify=self.ssl_options.tls_verify, + ) + + if http_response.status_code != 200: + raise Error(f"Failed to download schema bytes: {http_response.text}") + + # Extract schema bytes from the Arrow file + # The schema is at the beginning of the file + data = http_response.content + if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": + data = lz4.frame.decompress(data) + + # Return the schema bytes + return data + + def _results_message_to_execute_response(self, sea_response, command_id): + """ + Convert a SEA response to an ExecuteResponse and extract result data. 
+ + Args: + sea_response: The response from the SEA API + command_id: The command ID + + Returns: + tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, + result data object, and manifest object + """ + # Extract status + status_data = sea_response.get("status", {}) + state = CommandState.from_sea_state(status_data.get("state", "")) + + # Extract description from manifest + description = None + manifest_data = sea_response.get("manifest", {}) + schema_data = manifest_data.get("schema", {}) + columns_data = schema_data.get("columns", []) + + if columns_data: + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + description = columns if columns else None + + # Extract schema bytes for Arrow format + schema_bytes = None + format = manifest_data.get("format") + if format == "ARROW_STREAM": + # For ARROW format, we need to get the schema bytes + schema_bytes = self._get_schema_bytes(sea_response) + + # Check for compression + lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" + + # Initialize result_data_obj and manifest_obj + result_data_obj = None + manifest_obj = None + + result_data = sea_response.get("result", {}) + if result_data: + # Convert external links + external_links = None + if "external_links" in result_data: + external_links = [] + for link_data in result_data["external_links"]: + external_links.append( + ExternalLink( + external_link=link_data.get("external_link", ""), + expiration=link_data.get("expiration", ""), + chunk_index=link_data.get("chunk_index", 0), + byte_count=link_data.get("byte_count", 0), + row_count=link_data.get("row_count", 0), + row_offset=link_data.get("row_offset", 0), + next_chunk_index=link_data.get("next_chunk_index"), + next_chunk_internal_link=link_data.get( + "next_chunk_internal_link" + ), + http_headers=link_data.get("http_headers", {}), + ) + ) + + # Create the result data object + result_data_obj = ResultData( + data=result_data.get("data_array"), external_links=external_links + ) + + # Create the manifest object + manifest_obj = ResultManifest( + format=manifest_data.get("format", ""), + schema=manifest_data.get("schema", {}), + total_row_count=manifest_data.get("total_row_count", 0), + total_byte_count=manifest_data.get("total_byte_count", 0), + total_chunk_count=manifest_data.get("total_chunk_count", 0), + truncated=manifest_data.get("truncated", False), + chunks=manifest_data.get("chunks"), + result_compression=manifest_data.get("result_compression"), + ) + + execute_response = ExecuteResponse( + command_id=command_id, + status=state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=False, + arrow_schema_bytes=schema_bytes, + result_format=manifest_data.get("format"), + ) + + return execute_response, result_data_obj, manifest_obj def execute_command( self, @@ -274,41 +487,230 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: 
bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else None + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + "operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. 
+ + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + # Convert the response to an ExecuteResponse and extract result data + ( + execute_response, + result_data, + manifest, + ) = self._results_message_to_execute_response(response_data, command_id) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + result_data=result_data, + manifest=manifest, ) # == Metadata Operations == @@ -319,9 +721,9 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + raise NotImplementedError("get_catalogs is not implemented for SEA backend") def get_schemas( self, @@ -331,9 +733,9 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + raise NotImplementedError("get_schemas is not implemented for SEA backend") def get_tables( self, @@ -345,9 +747,9 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_tables is not implemented for SEA backend") def get_columns( self, @@ -359,6 +761,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + raise NotImplementedError("get_columns is not implemented for SEA backend") diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index b12c26eb0..6175b4ca0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,29 +42,12 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None -@dataclass -class ChunkInfo: - """Information about a chunk in the result set.""" - - chunk_index: int - byte_count: int - row_offset: int - row_count: int - - @dataclass class ResultData: """Result data from a statement 
execution.""" data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None - byte_count: Optional[int] = None - chunk_index: Optional[int] = None - next_chunk_index: Optional[int] = None - next_chunk_internal_link: Optional[str] = None - row_count: Optional[int] = None - row_offset: Optional[int] = None - attachment: Optional[bytes] = None @dataclass @@ -90,6 +73,5 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[ChunkInfo]] = None + chunks: Optional[List[Dict[str, Any]]] = None result_compression: Optional[str] = None - is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 4c5071dba..58921d793 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Representation of a parameter for a SQL statement.""" + """Parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Representation of a request to execute a SQL statement.""" + """Request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Representation of a request to get information about a statement.""" + """Request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Representation of a request to cancel a statement.""" + """Request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Representation of a request to close a statement.""" + """Request to close a statement.""" statement_id: str @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Representation of a request to create a new session.""" + """Request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Representation of a request to delete a session.""" + """Request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..c16f19da3 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,7 +14,6 @@ ResultData, ServiceError, ExternalLink, - ChunkInfo, ) @@ -44,18 +43,6 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) - chunks = None - if "chunks" in manifest_data: - chunks = [ - ChunkInfo( - chunk_index=chunk.get("chunk_index", 0), - byte_count=chunk.get("byte_count", 0), - row_offset=chunk.get("row_offset", 0), - row_count=chunk.get("row_count", 0), - ) - for chunk in manifest_data.get("chunks", []) - ] - return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -63,9 +50,8 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), 
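# --- Illustrative sketch (not part of this patch) -------------------------
# With ChunkInfo removed, manifest chunk metadata is carried as plain dicts.
# The snippet below shows the shape of data these models are meant to hold
# after the change; every literal value is invented for illustration and does
# not come from a real SEA response.

from databricks.sql.backend.sea.models.base import ResultData, ResultManifest

manifest = ResultManifest(
    format="JSON_ARRAY",
    schema={"column_count": 1, "columns": [{"name": "id", "type_name": "INT"}]},
    total_row_count=1,
    total_byte_count=8,
    total_chunk_count=1,
    truncated=False,
    chunks=[{"chunk_index": 0, "row_offset": 0, "row_count": 1}],
    result_compression=None,
)
inline_result = ResultData(data=[["1"]], external_links=None)
# --- end of sketch ---------------------------------------------------------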
total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=chunks, + chunks=manifest_data.get("chunks"), result_compression=manifest_data.get("result_compression"), - is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -94,19 +80,12 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, - byte_count=result_data.get("byte_count"), - chunk_index=result_data.get("chunk_index"), - next_chunk_index=result_data.get("next_chunk_index"), - next_chunk_internal_link=result_data.get("next_chunk_internal_link"), - row_count=result_data.get("row_count"), - row_offset=result_data.get("row_offset"), - attachment=result_data.get("attachment"), ) @dataclass class ExecuteStatementResponse: - """Representation of the response from executing a SQL statement.""" + """Response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -126,7 +105,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Representation of the response from getting information about a statement.""" + """Response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -146,7 +125,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Representation of the response from creating a new session.""" + """Response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..f0b931ee4 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +from typing import Callable, Dict, Any, Optional, Union, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index e824de1c2..f41f4b6d8 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,20 +3,22 @@ import logging import math import time +import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor +from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, + BackendType, + guid_to_hex_id, ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id - try: import pyarrow @@ -757,13 +759,11 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - - is_direct_results = ( + has_more_rows = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,25 +779,43 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + if 
direct_results and direct_results.resultSet: + assert direct_results.resultSet.results.startRowOffset == 0 + assert direct_results.resultSetMetadata + + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) + else: + arrow_queue_opt = None + command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - execute_response = ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) - return execute_response, is_direct_results - def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -822,6 +840,9 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -836,9 +857,15 @@ def get_execution_result( else: schema_bytes = None - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - is_direct_results = resp.hasMoreRows + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, + ssl_options=self._ssl_options, + ) status = self.get_query_state(command_id) @@ -846,11 +873,11 @@ def get_execution_result( command_id=command_id, status=status, description=description, + has_more_rows=has_more_rows, + results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, - result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -860,10 +887,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=resp.results, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -894,7 +918,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise 
ValueError(f"Unknown command state: {operation_state}") + raise ValueError(f"Invalid operation state: {operation_state}") return state @staticmethod @@ -976,14 +1000,10 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -991,10 +1011,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1016,14 +1033,10 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1031,10 +1044,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1060,14 +1070,10 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1075,10 +1081,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1108,14 +1111,10 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1123,10 +1122,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1156,14 +1152,10 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, 
is_direct_results = self._handle_execute_response( + execute_response, arrow_schema_bytes = self._handle_execute_response( resp, cursor ) - t_row_set = None - if resp.directResults and resp.directResults.resultSet: - t_row_set = resp.directResults.resultSet.results - return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1171,10 +1163,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - t_row_set=t_row_set, - max_download_threads=self.max_download_threads, - ssl_options=self._ssl_options, - is_direct_results=is_direct_results, + arrow_schema_bytes=arrow_schema_bytes, ) def _handle_execute_response(self, resp, cursor): @@ -1188,7 +1177,11 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..3d24a09a1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,9 +423,11 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None - result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index cf6940bb2..7ea4976c1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - is_direct_results: bool = False, + has_more_rows: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
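# --- Illustrative sketch (not part of this patch) -------------------------
# After the types.py change above, ExecuteResponse carries has_more_rows and
# the prebuilt results_queue instead of arrow_schema_bytes / result_format.
# A minimal construction is sketched below; the statement id, description row
# and flag values are placeholders chosen for illustration only.

from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse

response = ExecuteResponse(
    command_id=CommandId.from_sea_statement_id("example-statement-id"),
    status=CommandState.SUCCEEDED,
    description=[("col1", "string", None, None, None, None, True)],
    has_more_rows=False,
    results_queue=None,              # normally a ResultSetQueue built by the backend
    has_been_closed_server_side=True,
    lz4_compressed=False,
    is_staging_operation=False,
)
# --- end of sketch ---------------------------------------------------------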
- Parameters: - :param connection: The parent connection - :param backend: The backend client - :param arraysize: The max number of rows to fetch at a time (PEP-249) - :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - :param command_id: The command ID - :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server - :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue - :param description: column description of the results - :param is_staging_operation: Whether the command is a staging operation + Args: + connection: The parent connection + backend: The backend client + arraysize: The max number of rows to fetch at a time (PEP-249) + buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + command_id: The command ID + status: The command status + has_been_closed_server_side: Whether the command has been closed on the server + has_more_rows: Whether the command has more rows + results_queue: The results queue + description: column description of the results + is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,47 +157,25 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - t_row_set=None, - max_download_threads: int = 10, - ssl_options=None, - is_direct_results: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
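# --- Illustrative sketch (not part of this patch) -------------------------
# Under the revised flow, the thrift backend returns a tuple of
# (ExecuteResponse, arrow_schema_bytes) and the results queue travels inside
# the response, so constructing a ThriftResultSet reduces to forwarding both.
# `backend`, `resp` and `cursor` are assumed to already exist; this is a
# sketch of the wiring shown in the thrift_backend hunks above, not a
# drop-in helper from the library.

from databricks.sql.result_set import ThriftResultSet


def build_thrift_result_set(backend, resp, cursor, max_bytes, max_rows, use_cloud_fetch):
    execute_response, arrow_schema_bytes = backend._handle_execute_response(resp, cursor)
    return ThriftResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        thrift_client=backend,
        buffer_size_bytes=max_bytes,
        arraysize=max_rows,
        use_cloud_fetch=use_cloud_fetch,
        arrow_schema_bytes=arrow_schema_bytes,
    )
# --- end of sketch ---------------------------------------------------------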
- Parameters: - :param connection: The parent connection - :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access - :param buffer_size_bytes: Buffer size for fetching results - :param arraysize: Default number of rows to fetch - :param use_cloud_fetch: Whether to use cloud fetch for retrieving results - :param t_row_set: The TRowSet containing result data (if available) - :param max_download_threads: Maximum number of download threads for cloud fetch - :param ssl_options: SSL options for cloud fetch - :param is_direct_results: Whether there are more rows to fetch + Args: + connection: The parent connection + execute_response: Response from the execute command + thrift_client: The ThriftDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed - # Build the results queue if t_row_set is provided - results_queue = None - if t_row_set and execute_response.result_format is not None: - from databricks.sql.utils import ResultSetQueueFactory - - # Create the results queue using the provided format - results_queue = ResultSetQueueFactory.build_queue( - row_set_type=execute_response.result_format, - t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", - max_download_threads=max_download_threads, - lz4_compressed=execute_response.lz4_compressed, - description=execute_response.description, - ssl_options=ssl_options, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, @@ -207,8 +185,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - is_direct_results=is_direct_results, - results_queue=results_queue, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -218,7 +196,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, is_direct_results = self.backend.fetch_results( + results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -229,7 +207,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.is_direct_results = is_direct_results + self.has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -313,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -338,7 +316,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.is_direct_results + and self.has_more_rows ): self._fill_results_buffer() 
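# --- Illustrative sketch (not part of this patch) -------------------------
# The renamed has_more_rows flag is what keeps these fetch loops refilling the
# buffer. The helper below restates the same control flow against the base
# ResultSet attributes; it is a simplified sketch for readers, not a
# replacement for the methods in this file.

import pyarrow


def fetch_n_arrow(result_set, n):
    """Take buffered rows first, then refill while the server reports more."""
    table = result_set.results.next_n_rows(n)
    remaining = n - table.num_rows
    while (
        remaining > 0
        and not result_set.has_been_closed_server_side
        and result_set.has_more_rows
    ):
        result_set._fill_results_buffer()
        part = result_set.results.next_n_rows(remaining)
        table = pyarrow.concat_tables([table, part])
        remaining -= part.num_rows
    return table
# --- end of sketch ---------------------------------------------------------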
partial_results = self.results.next_n_rows(n_remaining_rows) @@ -353,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -379,7 +357,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.is_direct_results: + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -438,3 +416,76 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + +class SeaResultSet(ResultSet): + """ResultSet implementation for the SEA backend.""" + + def __init__( + self, + connection: "Connection", + execute_response: "ExecuteResponse", + sea_client: "SeaDatabricksClient", + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) + """ + + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. 
+ """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d7b1b74b4..edb13ef6d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: Optional[List[List[Any]]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2054d01d1..090ec255e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - is_direct_results=True, + has_more_rows=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,7 +104,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None - mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -185,7 +184,6 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -212,7 +210,6 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -257,10 +254,7 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() - mock_backend.fetch_results.return_value = (Mock(), False) - - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -478,6 +472,7 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq + 
mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..7249a59e6 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,30 +40,25 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) - - # Create a mock backend that will return the queue when _fill_results_buffer is called - mock_thrift_backend = Mock(spec=ThriftDatabricksClient) - mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) - - num_cols = len(initial_results[0]) if initial_results else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - description=description, - lz4_compressed=True, + has_more_rows=False, + description=Mock(), + lz4_compressed=Mock(), + results_queue=arrow_queue, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, - t_row_set=None, + thrift_client=None, ) + num_cols = len(initial_results[0]) if initial_results else 0 + rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod @@ -90,19 +85,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 - description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] - rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - description=description, - lz4_compressed=True, + has_more_rows=True, + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], + lz4_compressed=Mock(), + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e4a9e5cdd..7e025cf82 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - is_direct_results=False, + has_more_rows=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..2fa362b8e 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,348 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg 
= mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response + + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, } - assert set(allowed_configs) == expected_keys - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify 
get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + } + mock_http_client._make_request.return_value = execute_response + + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert 
"parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the result of a command execution.""" + # Set up mock 
response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 000000000..b691872af --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,200 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + return mock_response + + def test_init_with_execute_response( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.command_id == execute_response.command_id + assert result_set.status == CommandState.SUCCEEDED + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set.description == execute_response.description + + def test_close(self, mock_connection, mock_sea_client, execute_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + 
execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, execute_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, execute_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + 
connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57b5e9b58..b8de970db 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,14 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -835,10 +839,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value @@ -882,10 +885,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) + self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -920,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -948,21 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + 
self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -982,15 +987,15 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req - tcli_service_instance.GetOperationStatus.return_value = ( - ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ) + tcli_service_instance.GetOperationStatus.return_value = op_state + tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1004,10 +1009,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1019,7 +1024,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, ), closeOperation=Mock(), @@ -1035,12 +1040,11 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - ( - execute_response, - has_more_rows_result, - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response, _ = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) - self.assertEqual(is_direct_results, has_more_rows_result) + self.assertEqual(has_more_rows, execute_response.has_more_rows) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1049,10 +1053,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for is_direct_results, resp_type in itertools.product( + for has_more_rows, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): + with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1065,7 +1069,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=is_direct_results, + hasMoreRows=has_more_rows, results=results_mock, 
resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1098,7 +1102,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(is_direct_results, has_more_rows_resp) + self.assertEqual(has_more_rows, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1153,10 +1157,9 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1169,15 +1172,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1189,10 +1191,9 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1205,13 +1206,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,10 +1222,9 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1238,8 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), 
Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1251,7 +1249,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1264,10 +1262,9 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1280,8 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1295,7 +1291,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1310,10 +1306,9 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class, mock_result_set + self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1326,8 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1341,7 +1335,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertEqual(result, mock_result_set.return_value) + self.assertIsInstance(result, ResultSet) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1673,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2230,23 +2226,15 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class, mock_result_set + self, mock_handle_execute_response, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value - # Set up the mock to return a tuple with two values - mock_execute_response = Mock() - mock_arrow_schema = Mock() - mock_handle_execute_response.return_value = ( - mock_execute_response, - mock_arrow_schema, - ) - # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From c075b07164aeaf3d571aeb35c6d7227b92436aeb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:22:30 +0000 Subject: [PATCH 096/204] change logging level Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 0db326894..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -10,8 +10,7 @@ import subprocess from typing import List, Tuple -# Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) TEST_MODULES = [ From c62f76dce2d17f842708489da04c7a8d4255cf06 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:37:12 +0000 Subject: [PATCH 097/204] remove un-necessary changes Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 161 ++++++------------ src/databricks/sql/backend/sea/models/base.py | 20 ++- .../sql/backend/sea/models/requests.py | 14 +- .../sql/backend/sea/models/responses.py | 29 +++- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 137 ++++++++------- src/databricks/sql/backend/types.py | 8 +- src/databricks/sql/result_set.py | 159 ++++++----------- src/databricks/sql/utils.py | 6 +- 9 files changed, 240 insertions(+), 296 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index edd171b05..abe6bd1ab 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,115 +1,66 @@ -""" -Main script to run all SEA connector tests. - -This script runs all the individual test modules and displays -a summary of test results with visual indicators. 
-""" import os import sys import logging -import subprocess -from typing import List, Tuple +from databricks.sql.client import Connection logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -TEST_MODULES = [ - "test_sea_session", - "test_sea_sync_query", - "test_sea_async_query", - "test_sea_metadata", -] - - -def run_test_module(module_name: str) -> bool: - """Run a test module and return success status.""" - module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" - ) - - # Simply run the module as a script - each module handles its own test execution - result = subprocess.run( - [sys.executable, module_path], capture_output=True, text=True - ) - - # Log the output from the test module - if result.stdout: - for line in result.stdout.strip().split("\n"): - logger.info(line) - - if result.stderr: - for line in result.stderr.strip().split("\n"): - logger.error(line) - - return result.returncode == 0 - - -def run_tests() -> List[Tuple[str, bool]]: - """Run all tests and return results.""" - results = [] - - for module_name in TEST_MODULES: - try: - logger.info(f"\n{'=' * 50}") - logger.info(f"Running test: {module_name}") - logger.info(f"{'-' * 50}") - - success = run_test_module(module_name) - results.append((module_name, success)) - - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"Test {module_name}: {status}") - - except Exception as e: - logger.error(f"Error loading or running test {module_name}: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - results.append((module_name, False)) - - return results - - -def print_summary(results: List[Tuple[str, bool]]) -> None: - """Print a summary of test results.""" - logger.info(f"\n{'=' * 50}") - logger.info("TEST SUMMARY") - logger.info(f"{'-' * 50}") - - passed = sum(1 for _, success in results if success) - total = len(results) - - for module_name, success in results: - status = "✅ PASSED" if success else "❌ FAILED" - logger.info(f"{status} - {module_name}") - - logger.info(f"{'-' * 50}") - logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") - logger.info(f"{'=' * 50}") - - -if __name__ == "__main__": - # Check if required environment variables are set - required_vars = [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_HTTP_PATH", - "DATABRICKS_TOKEN", - ] - missing_vars = [var for var in required_vars if not os.environ.get(var)] - - if missing_vars: - logger.error( - f"Missing required environment variables: {', '.join(missing_vars)}" +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + sys.exit(1) + + logger.info(f"Connecting to {server_hostname}") + logger.info(f"HTTP Path: {http_path}") + if catalog: + logger.info(f"Using catalog: {catalog}") + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client" # add custom user agent ) - logger.error("Please set these variables before running the tests.") + + logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") + logger.info(f"backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + logger.error(traceback.format_exc()) sys.exit(1) + + logger.info("SEA session test completed successfully") - # Run all tests - results = run_tests() - - # Print summary - print_summary(results) - - # Exit with appropriate status code - all_passed = all(success for _, success in results) - sys.exit(0 if all_passed else 1) +if __name__ == "__main__": + test_sea_session() diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py index 6175b4ca0..b12c26eb0 100644 --- a/src/databricks/sql/backend/sea/models/base.py +++ b/src/databricks/sql/backend/sea/models/base.py @@ -42,12 +42,29 @@ class ExternalLink: http_headers: Optional[Dict[str, str]] = None +@dataclass +class ChunkInfo: + """Information about a chunk in the result set.""" + + chunk_index: int + byte_count: int + row_offset: int + row_count: int + + @dataclass class ResultData: """Result data from a statement execution.""" data: Optional[List[List[Any]]] = None external_links: Optional[List[ExternalLink]] = None + byte_count: Optional[int] = None + chunk_index: Optional[int] = None + next_chunk_index: Optional[int] = None + next_chunk_internal_link: Optional[str] = None + row_count: Optional[int] = None + row_offset: Optional[int] = None + attachment: Optional[bytes] = None @dataclass @@ -73,5 +90,6 @@ class ResultManifest: total_byte_count: int total_chunk_count: int truncated: bool = False - chunks: Optional[List[Dict[str, Any]]] = None + chunks: Optional[List[ChunkInfo]] = None result_compression: Optional[str] = None + is_volume_operation: Optional[bool] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 58921d793..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -10,7 +10,7 @@ @dataclass class StatementParameter: - """Parameter for a SQL statement.""" + 
"""Representation of a parameter for a SQL statement.""" name: str value: Optional[str] = None @@ -19,7 +19,7 @@ class StatementParameter: @dataclass class ExecuteStatementRequest: - """Request to execute a SQL statement.""" + """Representation of a request to execute a SQL statement.""" session_id: str statement: str @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class GetStatementRequest: - """Request to get information about a statement.""" + """Representation of a request to get information about a statement.""" statement_id: str @@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CancelStatementRequest: - """Request to cancel a statement.""" + """Representation of a request to cancel a statement.""" statement_id: str @@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CloseStatementRequest: - """Request to close a statement.""" + """Representation of a request to close a statement.""" statement_id: str @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Request to create a new session.""" + """Representation of a request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Request to delete a session.""" + """Representation of a request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index c16f19da3..0baf27ab2 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -14,6 +14,7 @@ ResultData, ServiceError, ExternalLink, + ChunkInfo, ) @@ -43,6 +44,18 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) + chunks = None + if "chunks" in manifest_data: + chunks = [ + ChunkInfo( + chunk_index=chunk.get("chunk_index", 0), + byte_count=chunk.get("byte_count", 0), + row_offset=chunk.get("row_offset", 0), + row_count=chunk.get("row_count", 0), + ) + for chunk in manifest_data.get("chunks", []) + ] + return ResultManifest( format=manifest_data.get("format", ""), schema=manifest_data.get("schema", {}), @@ -50,8 +63,9 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: total_byte_count=manifest_data.get("total_byte_count", 0), total_chunk_count=manifest_data.get("total_chunk_count", 0), truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), + chunks=chunks, result_compression=manifest_data.get("result_compression"), + is_volume_operation=manifest_data.get("is_volume_operation"), ) @@ -80,12 +94,19 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: return ResultData( data=result_data.get("data_array"), external_links=external_links, + byte_count=result_data.get("byte_count"), + chunk_index=result_data.get("chunk_index"), + next_chunk_index=result_data.get("next_chunk_index"), + next_chunk_internal_link=result_data.get("next_chunk_internal_link"), + row_count=result_data.get("row_count"), + row_offset=result_data.get("row_offset"), + attachment=result_data.get("attachment"), ) @dataclass class ExecuteStatementResponse: - """Response from executing a SQL statement.""" + """Representation of the response from executing a SQL statement.""" statement_id: str status: StatementStatus @@ -105,7 +126,7 @@ def from_dict(cls, data: 
Dict[str, Any]) -> "ExecuteStatementResponse": @dataclass class GetStatementResponse: - """Response from getting information about a statement.""" + """Representation of the response from getting information about a statement.""" statement_id: str status: StatementStatus @@ -125,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index f41f4b6d8..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,22 +3,20 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, - BackendType, - guid_to_hex_id, ExecuteResponse, ) +from databricks.sql.backend.utils import guid_to_hex_id + try: import pyarrow @@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ( - ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - ), - schema_bytes, + 
execute_response = ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, is_direct_results + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -840,9 +822,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -857,15 +836,9 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + is_direct_results = resp.hasMoreRows status = self.get_query_state(command_id) @@ -873,11 +846,11 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, - results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -887,7 +860,10 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=schema_bytes, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) state = CommandState.from_thrift_state(operation_state) if state is None: - raise ValueError(f"Invalid operation state: {operation_state}") + raise ValueError(f"Unknown command state: {operation_state}") return state @staticmethod @@ -1000,10 +976,14 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1011,7 +991,10 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1033,10 +1016,14 @@ def get_catalogs( ) resp = 
self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1044,7 +1031,10 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1070,10 +1060,14 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1081,7 +1075,10 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_tables( @@ -1111,10 +1108,14 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1122,7 +1123,10 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_columns( @@ -1152,10 +1156,14 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1163,7 +1171,10 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): @@ -1177,11 +1188,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - ( - execute_response, - arrow_schema_bytes, - ) = self._results_message_to_execute_response(resp, final_operation_state) - return execute_response, 
arrow_schema_bytes + return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..93bd7d525 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,11 +423,9 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None + description: Optional[List[Tuple]] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 7ea4976c1..cf6940bb2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, + is_direct_results: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -49,18 +49,18 @@ def __init__( """ A ResultSet manages the results of a single command. - Args: - connection: The parent connection - backend: The backend client - arraysize: The max number of rows to fetch at a time (PEP-249) - buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - command_id: The command ID - status: The command status - has_been_closed_server_side: Whether the command has been closed on the server - has_more_rows: Whether the command has more rows - results_queue: The results queue - description: column description of the results - is_staging_operation: Whether the command is a staging operation + Parameters: + :param connection: The parent connection + :param backend: The backend client + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -72,7 +72,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation @@ -157,25 +157,47 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Args: - connection: The parent connection - execute_response: Response from the execute command - thrift_client: The ThriftDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set + Parameters: + :param connection: The parent connection + :param execute_response: Response from the execute command + :param thrift_client: The ThriftDatabricksClient instance for direct access + :param buffer_size_bytes: Buffer size for fetching results + :param arraysize: Default number of rows to fetch + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results + :param t_row_set: The TRowSet containing result data (if available) + :param max_download_threads: Maximum number of download threads for cloud fetch + :param ssl_options: SSL options for cloud fetch + :param is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes - self._arrow_schema_bytes = arrow_schema_bytes + self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch self.lz4_compressed = execute_response.lz4_compressed + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + # Call parent constructor with common attributes super().__init__( connection=connection, @@ -185,8 +207,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + is_direct_results=is_direct_results, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, ) @@ -196,7 +218,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, has_more_rows = self.backend.fetch_results( + results, is_direct_results = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -207,7 +229,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -291,7 +313,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -316,7 +338,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() 
partial_results = self.results.next_n_rows(n_remaining_rows) @@ -331,7 +353,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -357,7 +379,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -416,76 +438,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for the SEA backend.""" - - def __init__( - self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. - - Args: - connection: The parent connection - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - execute_response: Response from the execute command (new style) - sea_response: Direct SEA response (legacy style) - """ - - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - ) - - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - """ - - raise NotImplementedError("fetchone is not implemented for SEA backend") - - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - An empty sequence is returned when no more rows are available. - """ - - raise NotImplementedError("fetchmany is not implemented for SEA backend") - - def fetchall(self) -> List[Row]: - """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. 
- """ - raise NotImplementedError("fetchall is not implemented for SEA backend") - - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") - - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index edb13ef6d..d7b1b74b4 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. From 199402eb6f09e8889cfb426935d2ac911543119a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:39:18 +0000 Subject: [PATCH 098/204] remove excess changes Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 191 ------------------ .../experimental/tests/test_sea_metadata.py | 98 --------- .../experimental/tests/test_sea_session.py | 71 ------- .../experimental/tests/test_sea_sync_query.py | 161 --------------- 5 files changed, 521 deletions(-) delete mode 100644 examples/experimental/tests/__init__.py delete mode 100644 examples/experimental/tests/test_sea_async_query.py delete mode 100644 examples/experimental/tests/test_sea_metadata.py delete mode 100644 examples/experimental/tests/test_sea_session.py delete mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py deleted file mode 100644 index a776377c3..000000000 --- a/examples/experimental/tests/test_sea_async_query.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Test for SEA asynchronous query execution functionality. -""" -import os -import sys -import logging -import time -from databricks.sql.client import Connection -from databricks.sql.backend.types import CommandState - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_async_query_with_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. 
- """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch enabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch enabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_without_cloud_fetch(): - """ - Test executing a query asynchronously using the SEA backend with cloud fetch disabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for asynchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query asynchronously - cursor = connection.cursor() - logger.info( - "Executing asynchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute_async("SELECT 1 as test_value") - logger.info( - "Asynchronous query submitted successfully with cloud fetch disabled" - ) - - # Check query state - logger.info("Checking query state...") - while cursor.is_query_pending(): - logger.info("Query is still pending, waiting...") - time.sleep(1) - - logger.info("Query is no longer pending, getting results...") - cursor.get_async_execution_result() - logger.info( - "Successfully retrieved asynchronous query results with cloud fetch disabled" - ) - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_async_query_exec(): - """ - Run both asynchronous query tests and return overall success. - """ - with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() - logger.info( - f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() - logger.info( - f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_async_query_exec() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py deleted file mode 100644 index a200d97d3..000000000 --- a/examples/experimental/tests/test_sea_metadata.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Test for SEA metadata functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_metadata(): - """ - Test metadata operations using the SEA backend. - - This function connects to a Databricks SQL endpoint using the SEA backend, - and executes metadata operations like catalogs(), schemas(), tables(), and columns(). - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - if not catalog: - logger.error( - "DATABRICKS_CATALOG environment variable is required for metadata tests." 
- ) - return False - - try: - # Create connection - logger.info("Creating connection for metadata operations") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Test catalogs - cursor = connection.cursor() - logger.info("Fetching catalogs...") - cursor.catalogs() - logger.info("Successfully fetched catalogs") - - # Test schemas - logger.info(f"Fetching schemas for catalog '{catalog}'...") - cursor.schemas(catalog_name=catalog) - logger.info("Successfully fetched schemas") - - # Test tables - logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") - cursor.tables(catalog_name=catalog, schema_name="default") - logger.info("Successfully fetched tables") - - # Test columns for a specific table - # Using a common table that should exist in most environments - logger.info( - f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." - ) - cursor.columns( - catalog_name=catalog, schema_name="default", table_name="customer" - ) - logger.info("Successfully fetched columns") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error during SEA metadata test: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_metadata() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py deleted file mode 100644 index 516c1bbb8..000000000 --- a/examples/experimental/tests/test_sea_session.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Test for SEA session management functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. - - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
- ) - return False - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - logger.info(f"Backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return False - - -if __name__ == "__main__": - success = test_sea_session() - sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py deleted file mode 100644 index 07be8aafc..000000000 --- a/examples/experimental/tests/test_sea_sync_query.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Test for SEA synchronous query execution functionality. -""" -import os -import sys -import logging -from databricks.sql.client import Connection - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def test_sea_sync_query_with_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch enabled. - - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch enabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch enabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=True, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query with cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch enabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_without_cloud_fetch(): - """ - Test executing a query synchronously using the SEA backend with cloud fetch disabled. 
- - This function connects to a Databricks SQL endpoint using the SEA backend, - executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. - """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error( - "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." - ) - return False - - try: - # Create connection with cloud fetch disabled - logger.info( - "Creating connection for synchronous query execution with cloud fetch disabled" - ) - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client", - use_cloud_fetch=False, - enable_query_result_lz4_compression=False, - ) - - logger.info( - f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" - ) - - # Execute a simple query - cursor = connection.cursor() - logger.info( - "Executing synchronous query without cloud fetch: SELECT 1 as test_value" - ) - cursor.execute("SELECT 1 as test_value") - logger.info("Query executed successfully with cloud fetch disabled") - - # Close resources - cursor.close() - connection.close() - logger.info("Successfully closed SEA session") - - return True - - except Exception as e: - logger.error( - f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" - ) - import traceback - - logger.error(traceback.format_exc()) - return False - - -def test_sea_sync_query_exec(): - """ - Run both synchronous query tests and return overall success. 
- """ - with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() - logger.info( - f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" - ) - - without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() - logger.info( - f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" - ) - - return with_cloud_fetch_success and without_cloud_fetch_success - - -if __name__ == "__main__": - success = test_sea_sync_query_exec() - sys.exit(0 if success else 1) From 8ac574ba46d7e2349fba105857e9ca2b7963e32b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 06:41:22 +0000 Subject: [PATCH 099/204] remove excess changes Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 11 +- tests/unit/test_fetches.py | 39 +++--- tests/unit/test_fetches_bench.py | 2 +- tests/unit/test_sea_result_set.py | 200 ------------------------------ tests/unit/test_thrift_backend.py | 138 +++++++++++---------- 5 files changed, 106 insertions(+), 284 deletions(-) delete mode 100644 tests/unit/test_sea_result_set.py diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 090ec255e..2054d01d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -210,6 +212,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -254,7 +257,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -472,7 +478,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7249a59e6..a649941e1 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -40,25 +40,30 @@ def 
make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - results_queue=arrow_queue, + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] return rs @staticmethod @@ -85,19 +90,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ], - lz4_compressed=Mock(), - results_queue=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 7e025cf82..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py deleted file mode 100644 index b691872af..000000000 --- a/tests/unit/test_sea_result_set.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Tests for the SeaResultSet class. - -This module contains tests for the SeaResultSet class, which implements -the result set functionality for the SEA (Statement Execution API) backend. 
-""" - -import pytest -from unittest.mock import patch, MagicMock, Mock - -from databricks.sql.result_set import SeaResultSet -from databricks.sql.backend.types import CommandId, CommandState, BackendType - - -class TestSeaResultSet: - """Test suite for the SeaResultSet class.""" - - @pytest.fixture - def mock_connection(self): - """Create a mock connection.""" - connection = Mock() - connection.open = True - return connection - - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() - - @pytest.fixture - def execute_response(self): - """Create a sample execute response.""" - mock_response = Mock() - mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") - mock_response.status = CommandState.SUCCEEDED - mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None - mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) - ] - mock_response.is_staging_operation = False - return mock_response - - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): - """Test initializing SeaResultSet with an execute response.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Verify basic properties - assert result_set.command_id == execute_response.command_id - assert result_set.status == CommandState.SUCCEEDED - assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client - assert result_set.buffer_size_bytes == 1000 - assert result_set.arraysize == 100 - assert result_set.description == execute_response.description - - def test_close(self, mock_connection, mock_sea_client, execute_response): - """Test closing a result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set when the connection is closed.""" - mock_connection.open = False - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - 
mock_sea_client.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set._fill_results_buffer() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b8de970db..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,18 +619,14 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 - - # Create a valid operation status - op_status = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=op_status, + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -839,9 +835,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") 
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ -885,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + _, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -923,9 +920,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -953,21 +948,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -987,15 +982,15 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) - # Mock the operation status response - op_state = ttypes.TGetOperationStatusResp( - status=self.okay_status, - operationState=ttypes.TOperationState.FINISHED_STATE, - ) - tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1009,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with 
self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1024,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1040,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, execute_response.has_more_rows) + self.assertEqual(is_direct_results, has_more_rows_result) @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1053,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1069,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1102,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1157,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1172,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to 
client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1191,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1206,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1222,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1237,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1249,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1262,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1277,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1291,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client 
req = tcli_service_instance.GetTables.call_args[0][0] @@ -1306,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1321,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1335,7 +1341,7 @@ def test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -1667,9 +1673,7 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock( - return_value=(Mock(), Mock()) - ) + thrift_backend._results_message_to_execute_response = Mock() # Create a mock response with a real operation handle mock_resp = Mock() @@ -2226,15 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", - return_value=(Mock(), Mock()), + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From b1acc5bffd676c7382be86ad12db011a8ebb38b4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 10:46:57 +0000 Subject: [PATCH 100/204] remove _get_schema_bytes (for now) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 77 +---------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 6d627162d..1d31f2afe 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -301,74 +301,6 @@ def get_allowed_session_configurations() -> List[str]: """ 
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - def _get_schema_bytes(self, sea_response) -> Optional[bytes]: - """ - Extract schema bytes from the SEA response. - - For ARROW format, we need to get the schema bytes from the first chunk. - If the first chunk is not available, we need to get it from the server. - - Args: - sea_response: The response from the SEA API - - Returns: - bytes: The schema bytes or None if not available - """ - import requests - import lz4.frame - - # Check if we have the first chunk in the response - result_data = sea_response.get("result", {}) - external_links = result_data.get("external_links", []) - - if not external_links: - return None - - # Find the first chunk (chunk_index = 0) - first_chunk = None - for link in external_links: - if link.get("chunk_index") == 0: - first_chunk = link - break - - if not first_chunk: - # Try to fetch the first chunk from the server - statement_id = sea_response.get("statement_id") - if not statement_id: - return None - - chunks_response = self.get_chunk_links(statement_id, 0) - if not chunks_response.external_links: - return None - - first_chunk = chunks_response.external_links[0].__dict__ - - # Download the first chunk to get the schema bytes - external_link = first_chunk.get("external_link") - http_headers = first_chunk.get("http_headers", {}) - - if not external_link: - return None - - # Use requests to download the first chunk - http_response = requests.get( - external_link, - headers=http_headers, - verify=self.ssl_options.tls_verify, - ) - - if http_response.status_code != 200: - raise Error(f"Failed to download schema bytes: {http_response.text}") - - # Extract schema bytes from the Arrow file - # The schema is at the beginning of the file - data = http_response.content - if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": - data = lz4.frame.decompress(data) - - # Return the schema bytes - return data - def _results_message_to_execute_response(self, sea_response, command_id): """ Convert a SEA response to an ExecuteResponse and extract result data. 
@@ -411,13 +343,6 @@ def _results_message_to_execute_response(self, sea_response, command_id): ) description = columns if columns else None - # Extract schema bytes for Arrow format - schema_bytes = None - format = manifest_data.get("format") - if format == "ARROW_STREAM": - # For ARROW format, we need to get the schema bytes - schema_bytes = self._get_schema_bytes(sea_response) - # Check for compression lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" @@ -472,7 +397,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, - arrow_schema_bytes=schema_bytes, + arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW result_format=manifest_data.get("format"), ) From ef2a7eefcf158c6d033664fb5d844c40d07eb65e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 10:48:51 +0000 Subject: [PATCH 101/204] redundant comments Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1d31f2afe..15941d296 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -487,12 +487,11 @@ def execute_command( # Store the command ID in the cursor cursor.active_command_id = command_id - # If async operation, return None and let the client poll for results + # If async operation, return and let the client poll for results if async_op: return None # For synchronous operation, wait for the statement to complete - # Poll until the statement is done status = response.status state = status.state From af8f74e9f3c8bce7d484d312e6f6123d5e770edd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:39:14 +0000 Subject: [PATCH 102/204] remove fetch phase methods Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 15941d296..42903d09d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -87,8 +87,6 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH = BASE_PATH + "statements" STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" - CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks" - CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}" def __init__( self, @@ -278,19 +276,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]: """ return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper()) - @staticmethod - def is_session_configuration_parameter_supported(name: str) -> bool: - """ - Check if a session configuration parameter is supported. 
- - Args: - name: The name of the session configuration parameter - - Returns: - True if the parameter is supported, False otherwise - """ - return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP - @staticmethod def get_allowed_session_configurations() -> List[str]: """ From 5540c5c4a8198f5820e275a379110c13d86e0517 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:45:56 +0000 Subject: [PATCH 103/204] reduce code repetititon + introduce gaps after multi line pydocs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 78 +++++-------------- .../sql/backend/sea/models/responses.py | 18 ++--- tests/unit/test_sea_backend.py | 2 +- 3 files changed, 30 insertions(+), 68 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 42903d09d..0e34d2470 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -40,6 +40,11 @@ GetStatementResponse, CreateSessionResponse, ) +from databricks.sql.backend.sea.models.responses import ( + parse_status, + parse_manifest, + parse_result, +) logger = logging.getLogger(__name__) @@ -75,9 +80,6 @@ def _filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. - - This implementation provides session management functionality for SEA, - while other operations raise NotImplementedError. """ # SEA API paths @@ -119,7 +121,6 @@ def __init__( ) self._max_download_threads = kwargs.get("max_download_threads", 10) - self.ssl_options = ssl_options # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -298,16 +299,16 @@ def _results_message_to_execute_response(self, sea_response, command_id): tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response, result data object, and manifest object """ - # Extract status - status_data = sea_response.get("status", {}) - state = CommandState.from_sea_state(status_data.get("state", "")) - # Extract description from manifest + # Parse the response + status = parse_status(sea_response) + manifest_obj = parse_manifest(sea_response) + result_data_obj = parse_result(sea_response) + + # Extract description from manifest schema description = None - manifest_data = sea_response.get("manifest", {}) - schema_data = manifest_data.get("schema", {}) + schema_data = manifest_obj.schema columns_data = schema_data.get("columns", []) - if columns_data: columns = [] for col_data in columns_data: @@ -329,61 +330,17 @@ def _results_message_to_execute_response(self, sea_response, command_id): description = columns if columns else None # Check for compression - lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" - - # Initialize result_data_obj and manifest_obj - result_data_obj = None - manifest_obj = None - - result_data = sea_response.get("result", {}) - if result_data: - # Convert external links - external_links = None - if "external_links" in result_data: - external_links = [] - for link_data in result_data["external_links"]: - external_links.append( - ExternalLink( - external_link=link_data.get("external_link", ""), - expiration=link_data.get("expiration", ""), - chunk_index=link_data.get("chunk_index", 0), - byte_count=link_data.get("byte_count", 0), - row_count=link_data.get("row_count", 0), - row_offset=link_data.get("row_offset", 0), - next_chunk_index=link_data.get("next_chunk_index"), 
- next_chunk_internal_link=link_data.get( - "next_chunk_internal_link" - ), - http_headers=link_data.get("http_headers", {}), - ) - ) - - # Create the result data object - result_data_obj = ResultData( - data=result_data.get("data_array"), external_links=external_links - ) - - # Create the manifest object - manifest_obj = ResultManifest( - format=manifest_data.get("format", ""), - schema=manifest_data.get("schema", {}), - total_row_count=manifest_data.get("total_row_count", 0), - total_byte_count=manifest_data.get("total_byte_count", 0), - total_chunk_count=manifest_data.get("total_chunk_count", 0), - truncated=manifest_data.get("truncated", False), - chunks=manifest_data.get("chunks"), - result_compression=manifest_data.get("result_compression"), - ) + lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME" execute_response = ExecuteResponse( command_id=command_id, - status=state, + status=status.state, description=description, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW - result_format=manifest_data.get("format"), + result_format=manifest_obj.format, ) return execute_response, result_data_obj, manifest_obj @@ -419,6 +376,7 @@ def execute_command( Returns: ResultSet: A SeaResultSet instance for the executed command """ + if session_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA session ID") @@ -506,6 +464,7 @@ def cancel_command(self, command_id: CommandId) -> None: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -528,6 +487,7 @@ def close_command(self, command_id: CommandId) -> None: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -553,6 +513,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") @@ -587,6 +548,7 @@ def get_execution_result( Raises: ValueError: If the command ID is invalid """ + if command_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA command ID") diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 0baf27ab2..dae37b1ae 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -18,7 +18,7 @@ ) -def _parse_status(data: Dict[str, Any]) -> StatementStatus: +def parse_status(data: Dict[str, Any]) -> StatementStatus: """Parse status from response data.""" status_data = data.get("status", {}) error = None @@ -40,7 +40,7 @@ def _parse_status(data: Dict[str, Any]) -> StatementStatus: ) -def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: +def parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) @@ -69,7 +69,7 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: ) -def _parse_result(data: Dict[str, Any]) -> ResultData: +def parse_result(data: Dict[str, Any]) -> ResultData: """Parse result data from response data.""" result_data = data.get("result", {}) external_links = None @@ -118,9 +118,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse 
from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=_parse_status(data), - manifest=_parse_manifest(data), - result=_parse_result(data), + status=parse_status(data), + manifest=parse_manifest(data), + result=parse_result(data), ) @@ -138,9 +138,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=_parse_status(data), - manifest=_parse_manifest(data), - result=_parse_result(data), + status=parse_status(data), + manifest=parse_manifest(data), + result=parse_result(data), ) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2fa362b8e..01424a4d2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -536,7 +536,7 @@ def test_get_execution_result( print(result) # Verify basic properties of the result - assert result.statement_id == "test-statement-123" + assert result.command_id.to_sea_statement_id() == "test-statement-123" assert result.status == CommandState.SUCCEEDED # Verify the HTTP request From efe3881c1b4f7ff31305bcf64a7e39acfd72e590 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:46:53 +0000 Subject: [PATCH 104/204] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 0e34d2470..03080bf5a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -19,14 +19,9 @@ BackendType, ExecuteResponse, ) -from databricks.sql.exc import Error, NotSupportedError, ServerOperationError +from databricks.sql.exc import ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions -from databricks.sql.backend.sea.models.base import ( - ResultData, - ExternalLink, - ResultManifest, -) from databricks.sql.backend.sea.models import ( ExecuteStatementRequest, From 36ab59bbdb3e942ede39a2f32844bf3697d15a33 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:51:04 +0000 Subject: [PATCH 105/204] move description extraction to helper func Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 60 ++++++++++++++--------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 03080bf5a..014912c8f 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -282,6 +282,43 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) + def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: + """ + Extract column description from a manifest object. 
+ + Args: + manifest_obj: The ResultManifest object containing schema information + + Returns: + Optional[List]: A list of column tuples or None if no columns are found + """ + + schema_data = manifest_obj.schema + columns_data = schema_data.get("columns", []) + + if not columns_data: + return None + + columns = [] + for col_data in columns_data: + if not isinstance(col_data, dict): + continue + + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + + return columns if columns else None + def _results_message_to_execute_response(self, sea_response, command_id): """ Convert a SEA response to an ExecuteResponse and extract result data. @@ -301,28 +338,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): result_data_obj = parse_result(sea_response) # Extract description from manifest schema - description = None - schema_data = manifest_obj.schema - columns_data = schema_data.get("columns", []) - if columns_data: - columns = [] - for col_data in columns_data: - if not isinstance(col_data, dict): - continue - - # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) - columns.append( - ( - col_data.get("name", ""), # name - col_data.get("type_name", ""), # type_code - None, # display_size (not provided by SEA) - None, # internal_size (not provided by SEA) - col_data.get("precision"), # precision - col_data.get("scale"), # scale - col_data.get("nullable", True), # null_ok - ) - ) - description = columns if columns else None + description = self._extract_description_from_manifest(manifest_obj) # Check for compression lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME" From 1d57c996afff5727c1e66a36e9da82f75777d6f1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 11:52:06 +0000 Subject: [PATCH 106/204] formatting (black) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 014912c8f..1dde8e4dc 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -295,10 +295,10 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: schema_data = manifest_obj.schema columns_data = schema_data.get("columns", []) - + if not columns_data: return None - + columns = [] for col_data in columns_data: if not isinstance(col_data, dict): @@ -316,7 +316,7 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]: col_data.get("nullable", True), # null_ok ) ) - + return columns if columns else None def _results_message_to_execute_response(self, sea_response, command_id): From df6dac2bd84b7e3e2b71f51469571396166a5b34 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:20:49 +0000 Subject: [PATCH 107/204] add more unit tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 299 ++++++++++++++++++++++++++++++++- 1 file changed, 296 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 01424a4d2..e6d293e5f 
100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -9,12 +9,15 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.backend import ( + SeaDatabricksClient, + _filter_session_configuration, +) from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error, NotSupportedError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError class TestSeaBackend: @@ -305,6 +308,32 @@ def test_execute_command_async( assert isinstance(mock_cursor.active_command_id, CommandId) assert mock_cursor.active_command_id.guid == "test-statement-456" + def test_execute_command_async_missing_statement_id( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing an async command that returns no statement ID.""" + # Set up mock response with status but no statement_id + mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} + + # Call the method and expect an error + with pytest.raises(ServerOperationError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, + ) + + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value + ) + def test_execute_command_with_polling( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): @@ -442,6 +471,32 @@ def test_execute_command_failure( assert "Statement execution did not succeed" in str(excinfo.value) + def test_execute_command_missing_statement_id( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that returns no statement ID.""" + # Set up mock response with status but no statement_id + mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} + + # Call the method and expect an error + with pytest.raises(ServerOperationError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value + ) + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): """Test canceling a command.""" # Set up mock response @@ -533,7 +588,6 @@ def test_get_execution_result( # Create a real result set to verify the implementation result = sea_client.get_execution_result(sea_command_id, mock_cursor) - print(result) # Verify basic properties of the result assert result.command_id.to_sea_statement_id() == "test-statement-123" @@ -546,3 +600,242 @@ def test_get_execution_result( assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( "test-statement-123" ) + + def test_get_execution_result_with_invalid_command_id( + self, sea_client, mock_cursor + ): + """Test getting execution result with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + 
mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.get_execution_result(command_id, mock_cursor) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_max_download_threads_property(self, mock_http_client): + """Test the max_download_threads property.""" + # Test with default value + client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + assert client.max_download_threads == 10 + + # Test with custom value + client = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=5, + ) + assert client.max_download_threads == 5 + + def test_get_default_session_configuration_value(self): + """Test the get_default_session_configuration_value static method.""" + # Test with supported configuration parameter + value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE") + assert value == "true" + + # Test with unsupported configuration parameter + value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert value is None + + # Test with case-insensitive parameter name + value = SeaDatabricksClient.get_default_session_configuration_value("ansi_mode") + assert value == "true" + + def test_get_allowed_session_configurations(self): + """Test the get_allowed_session_configurations static method.""" + configs = SeaDatabricksClient.get_allowed_session_configurations() + assert isinstance(configs, list) + assert len(configs) > 0 + assert "ANSI_MODE" in configs + + def test_extract_description_from_manifest(self, sea_client): + """Test the _extract_description_from_manifest method.""" + # Test with valid manifest containing columns + manifest_obj = MagicMock() + manifest_obj.schema = { + "columns": [ + { + "name": "col1", + "type_name": "STRING", + "precision": 10, + "scale": 2, + "nullable": True, + }, + { + "name": "col2", + "type_name": "INT", + "nullable": False, + }, + ] + } + + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is not None + assert len(description) == 2 + + # Check first column + assert description[0][0] == "col1" # name + assert description[0][1] == "STRING" # type_code + assert description[0][4] == 10 # precision + assert description[0][5] == 2 # scale + assert description[0][6] is True # null_ok + + # Check second column + assert description[1][0] == "col2" # name + assert description[1][1] == "INT" # type_code + assert description[1][6] is False # null_ok + + # Test with manifest containing non-dict column + manifest_obj.schema = {"columns": ["not_a_dict"]} + description = sea_client._extract_description_from_manifest(manifest_obj) + assert ( + description is None + ) # Method returns None when no valid columns are found + + # Test with manifest without columns + manifest_obj.schema = {} + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is None + + def test_cancel_command_with_invalid_command_id(self, sea_client): + 
"""Test canceling a command with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.cancel_command(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_close_command_with_invalid_command_id(self, sea_client): + """Test closing a command with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.close_command(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_get_query_state_with_invalid_command_id(self, sea_client): + """Test getting query state with an invalid command ID.""" + # Create a Thrift command ID (not SEA) + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.get_query_state(command_id) + + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_unimplemented_metadata_methods( + self, sea_client, sea_session_id, mock_cursor + ): + """Test that metadata methods raise NotImplementedError.""" + # Test get_catalogs + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) + assert "get_catalogs is not implemented for SEA backend" in str(excinfo.value) + + # Test get_schemas + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) + assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) + + # Test get_schemas with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_schemas( + sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" + ) + assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) + + # Test get_tables + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) + assert "get_tables is not implemented for SEA backend" in str(excinfo.value) + + # Test get_tables with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_tables( + sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + table_types=["TABLE", "VIEW"], + ) + assert "get_tables is not implemented for SEA backend" in str(excinfo.value) + + # Test get_columns + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) + assert "get_columns is not implemented for SEA backend" in str(excinfo.value) + + # Test get_columns with optional parameters + with pytest.raises(NotImplementedError) as excinfo: + sea_client.get_columns( + 
sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + column_name="column", + ) + assert "get_columns is not implemented for SEA backend" in str(excinfo.value) + + def test_execute_command_with_invalid_session_id(self, sea_client, mock_cursor): + """Test executing a command with an invalid session ID type.""" + # Create a Thrift session ID (not SEA) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + # Call the method and expect an error + with pytest.raises(ValueError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Not a valid SEA session ID" in str(excinfo.value) From ad0e527c6a67ba5d8d89d63655c33f27d2acbe7a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:34:25 +0000 Subject: [PATCH 108/204] streamline unit tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 534 ++++++++++----------------------- 1 file changed, 166 insertions(+), 368 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index e6d293e5f..4b1ec55a3 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -5,7 +5,6 @@ the Databricks SQL connector's SEA backend functionality. """ -import json import pytest from unittest.mock import patch, MagicMock, Mock @@ -13,7 +12,6 @@ SeaDatabricksClient, _filter_session_configuration, ) -from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider @@ -68,10 +66,28 @@ def mock_cursor(self): """Create a mock cursor.""" cursor = Mock() cursor.active_command_id = None + cursor.buffer_size_bytes = 1000 + cursor.arraysize = 100 return cursor - def test_init_extracts_warehouse_id(self, mock_http_client): - """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + @pytest.fixture + def thrift_session_id(self): + """Create a Thrift session ID (not SEA).""" + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + return SessionId.from_thrift_handle(mock_thrift_handle) + + @pytest.fixture + def thrift_command_id(self): + """Create a Thrift command ID (not SEA).""" + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + return CommandId.from_thrift_handle(mock_thrift_operation_handle) + + def test_initialization(self, mock_http_client): + """Test client initialization and warehouse ID extraction.""" # Test with warehouses format client1 = SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -82,6 +98,7 @@ def test_init_extracts_warehouse_id(self, mock_http_client): ssl_options=SSLOptions(), ) assert client1.warehouse_id == "abc123" + assert client1.max_download_threads == 10 # Default value # Test with endpoints format client2 = SeaDatabricksClient( @@ -94,8 +111,19 @@ def test_init_extracts_warehouse_id(self, 
mock_http_client): ) assert client2.warehouse_id == "def456" - def test_init_raises_error_for_invalid_http_path(self, mock_http_client): - """Test that the constructor raises an error for invalid HTTP paths.""" + # Test with custom max_download_threads + client3 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=5, + ) + assert client3.max_download_threads == 5 + + # Test with invalid HTTP path with pytest.raises(ValueError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -107,30 +135,21 @@ def test_init_raises_error_for_invalid_http_path(self, mock_http_client): ) assert "Could not extract warehouse ID" in str(excinfo.value) - def test_open_session_basic(self, sea_client, mock_http_client): - """Test the open_session method with minimal parameters.""" - # Set up mock response + def test_session_management(self, sea_client, mock_http_client, thrift_session_id): + """Test session management methods.""" + # Test open_session with minimal parameters mock_http_client._make_request.return_value = {"session_id": "test-session-123"} - - # Call the method session_id = sea_client.open_session(None, None, None) - - # Verify the result assert isinstance(session_id, SessionId) assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-123" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} ) - def test_open_session_with_all_parameters(self, sea_client, mock_http_client): - """Test the open_session method with all parameters.""" - # Set up mock response + # Test open_session with all parameters + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"session_id": "test-session-456"} - - # Call the method with all parameters, including both supported and unsupported configurations session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter @@ -138,16 +157,8 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): } catalog = "test_catalog" schema = "test_schema" - session_id = sea_client.open_session(session_config, catalog, schema) - - # Verify the result - assert isinstance(session_id, SessionId) - assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-456" - - # Verify the HTTP request - only supported parameters should be included - # and keys should be in lowercase expected_data = { "warehouse_id": "abc123", "session_confs": { @@ -157,60 +168,37 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): "catalog": catalog, "schema": schema, } - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data=expected_data ) - def test_open_session_error_handling(self, sea_client, mock_http_client): - """Test error handling in the open_session method.""" - # Set up mock response without session_id + # Test open_session error handling + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {} - - # Call the method and expect an error with pytest.raises(Error) as excinfo: sea_client.open_session(None, None, None) - assert "Failed to 
create session" in str(excinfo.value) - def test_close_session_valid_id(self, sea_client, mock_http_client): - """Test closing a session with a valid session ID.""" - # Create a valid SEA session ID + # Test close_session with valid ID + mock_http_client.reset_mock() session_id = SessionId.from_sea_session_id("test-session-789") - - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method sea_client.close_session(session_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="DELETE", path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) - def test_close_session_invalid_id_type(self, sea_client): - """Test closing a session with an invalid session ID type.""" - # Create a Thrift session ID (not SEA) - mock_thrift_handle = MagicMock() - mock_thrift_handle.sessionId.guid = b"guid" - mock_thrift_handle.sessionId.secret = b"secret" - session_id = SessionId.from_thrift_handle(mock_thrift_handle) - - # Call the method and expect an error + # Test close_session with invalid ID type with pytest.raises(ValueError) as excinfo: - sea_client.close_session(session_id) - + sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) - # Tests for command execution and management - - def test_execute_command_sync( + def test_command_execution_sync( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command synchronously.""" - # Set up mock responses + """Test synchronous command execution.""" + # Test synchronous execution execute_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, @@ -230,11 +218,9 @@ def test_execute_command_sync( } mock_http_client._make_request.return_value = execute_response - # Mock the get_execution_result method with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: - # Call the method result = sea_client.execute_command( operation="SELECT 1", session_id=sea_session_id, @@ -247,38 +233,43 @@ def test_execute_command_sync( async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the result assert result == "mock_result_set" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "warehouse_id" in kwargs["data"] - assert "session_id" in kwargs["data"] - assert "statement" in kwargs["data"] - assert kwargs["data"]["statement"] == "SELECT 1" - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() cmd_id_arg = mock_get_result.call_args[0][0] assert isinstance(cmd_id_arg, CommandId) assert cmd_id_arg.guid == "test-statement-123" - def test_execute_command_async( + # Test with invalid session ID + with pytest.raises(ValueError) as excinfo: + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + thrift_session_id = SessionId.from_thrift_handle(mock_thrift_handle) + + sea_client.execute_command( + operation="SELECT 1", + session_id=thrift_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, 
+ enforce_embedded_schema_correctness=False, + ) + assert "Not a valid SEA session ID" in str(excinfo.value) + + def test_command_execution_async( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command asynchronously.""" - # Set up mock response + """Test asynchronous command execution.""" + # Test asynchronous execution execute_response = { "statement_id": "test-statement-456", "status": {"state": "PENDING"}, } mock_http_client._make_request.return_value = execute_response - # Call the method result = sea_client.execute_command( operation="SELECT 1", session_id=sea_session_id, @@ -288,34 +279,16 @@ def test_execute_command_async( cursor=mock_cursor, use_cloud_fetch=False, parameters=[], - async_op=True, # Async mode + async_op=True, enforce_embedded_schema_correctness=False, ) - - # Verify the result is None for async operation assert result is None - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.STATEMENT_PATH - assert "wait_timeout" in kwargs["data"] - assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout - - # Verify the command ID was stored in the cursor - assert hasattr(mock_cursor, "active_command_id") assert isinstance(mock_cursor.active_command_id, CommandId) assert mock_cursor.active_command_id.guid == "test-statement-456" - def test_execute_command_async_missing_statement_id( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing an async command that returns no statement ID.""" - # Set up mock response with status but no statement_id + # Test async with missing statement ID + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} - - # Call the method and expect an error with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -326,19 +299,18 @@ def test_execute_command_async_missing_statement_id( cursor=mock_cursor, use_cloud_fetch=False, parameters=[], - async_op=True, # Async mode + async_op=True, enforce_embedded_schema_correctness=False, ) - assert "Failed to execute command: No statement ID returned" in str( excinfo.value ) - def test_execute_command_with_polling( + def test_command_execution_advanced( self, sea_client, mock_http_client, mock_cursor, sea_session_id ): - """Test executing a command that requires polling.""" - # Set up mock responses for initial request and polling + """Test advanced command execution scenarios.""" + # Test with polling initial_response = { "statement_id": "test-statement-789", "status": {"state": "RUNNING"}, @@ -349,17 +321,12 @@ def test_execute_command_with_polling( "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, "result": {"data": []}, } - - # Configure mock to return different responses on subsequent calls mock_http_client._make_request.side_effect = [initial_response, poll_response] - # Mock the get_execution_result method with patch.object( sea_client, "get_execution_result", return_value="mock_result_set" ) as mock_get_result: - # Mock time.sleep to avoid actual delays with patch("time.sleep"): - # Call the method result = sea_client.execute_command( operation="SELECT * FROM large_table", session_id=sea_session_id, @@ -372,39 +339,22 @@ def test_execute_command_with_polling( async_op=False, enforce_embedded_schema_correctness=False, ) - - # 
Verify the result assert result == "mock_result_set" - # Verify the HTTP requests (initial and poll) - assert mock_http_client._make_request.call_count == 2 - - # Verify get_execution_result was called with the right command ID - mock_get_result.assert_called_once() - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-789" - - def test_execute_command_with_parameters( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command with parameters.""" - # Set up mock response + # Test with parameters + mock_http_client.reset_mock() + mock_http_client._make_request.side_effect = None # Reset side_effect execute_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - - # Create parameter mock param = MagicMock() param.name = "param1" param.value = "value1" param.type = "STRING" - # Mock the get_execution_result method - with patch.object(sea_client, "get_execution_result") as mock_get_result: - # Call the method with parameters + with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( operation="SELECT * FROM table WHERE col = :param1", session_id=sea_session_id, @@ -417,9 +367,6 @@ def test_execute_command_with_parameters( async_op=False, enforce_embedded_schema_correctness=False, ) - - # Verify the HTTP request contains parameters - mock_http_client._make_request.assert_called_once() args, kwargs = mock_http_client._make_request.call_args assert "parameters" in kwargs["data"] assert len(kwargs["data"]["parameters"]) == 1 @@ -427,11 +374,8 @@ def test_execute_command_with_parameters( assert kwargs["data"]["parameters"][0]["value"] == "value1" assert kwargs["data"]["parameters"][0]["type"] == "STRING" - def test_execute_command_failure( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that fails.""" - # Set up mock response for a failed execution + # Test execution failure + mock_http_client.reset_mock() error_response = { "statement_id": "test-statement-123", "status": { @@ -442,43 +386,30 @@ def test_execute_command_failure( }, }, } + mock_http_client._make_request.return_value = error_response - # Configure the mock to return the error response for the initial request - # and then raise an exception when trying to poll (to simulate immediate failure) - mock_http_client._make_request.side_effect = [ - error_response, # Initial response - Error( - "Statement execution did not succeed: Syntax error in SQL" - ), # Will be raised during polling - ] - - # Mock time.sleep to avoid actual delays with patch("time.sleep"): - # Call the method and expect an error - with pytest.raises(Error) as excinfo: - sea_client.execute_command( - operation="SELECT * FROM nonexistent_table", - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Statement execution did not succeed" in str(excinfo.value) - - def test_execute_command_missing_statement_id( - self, sea_client, mock_http_client, mock_cursor, sea_session_id - ): - """Test executing a command that returns no statement ID.""" - # Set up mock response with status but no statement_id + with patch.object( + sea_client, "get_query_state", return_value=CommandState.FAILED + ): + with 
pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert "Statement execution did not succeed" in str(excinfo.value) + + # Test missing statement ID + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} - - # Call the method and expect an error with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", @@ -492,70 +423,68 @@ def test_execute_command_missing_statement_id( async_op=False, enforce_embedded_schema_correctness=False, ) - assert "Failed to execute command: No statement ID returned" in str( excinfo.value ) - def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): - """Test canceling a command.""" - # Set up mock response + def test_command_management( + self, + sea_client, + mock_http_client, + sea_command_id, + thrift_command_id, + mock_cursor, + ): + """Test command management methods.""" + # Test cancel_command mock_http_client._make_request.return_value = {} - - # Call the method sea_client.cancel_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "POST" - assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="POST", + path=sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_close_command(self, sea_client, mock_http_client, sea_command_id): - """Test closing a command.""" - # Set up mock response - mock_http_client._make_request.return_value = {} + # Test cancel_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.cancel_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) - # Call the method + # Test close_command + mock_http_client.reset_mock() sea_client.close_command(sea_command_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "DELETE" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="DELETE", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): - """Test getting the state of a query.""" - # Set up mock response + # Test close_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.close_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) + + # Test get_query_state + mock_http_client.reset_mock() mock_http_client._make_request.return_value = { "statement_id": "test-statement-123", "status": {"state": "RUNNING"}, } - - # Call the method state = sea_client.get_query_state(sea_command_id) - - # Verify the result assert state == CommandState.RUNNING - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = 
mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" + mock_http_client._make_request.assert_called_with( + method="GET", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, ) - def test_get_execution_result( - self, sea_client, mock_http_client, mock_cursor, sea_command_id - ): - """Test getting the result of a command execution.""" - # Set up mock response + # Test get_query_state with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.get_query_state(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) + + # Test get_execution_result + mock_http_client.reset_mock() sea_response = { "statement_id": "test-statement-123", "status": {"state": "SUCCEEDED"}, @@ -585,66 +514,18 @@ def test_get_execution_result( }, } mock_http_client._make_request.return_value = sea_response - - # Create a real result set to verify the implementation result = sea_client.get_execution_result(sea_command_id, mock_cursor) - - # Verify basic properties of the result assert result.command_id.to_sea_statement_id() == "test-statement-123" assert result.status == CommandState.SUCCEEDED - # Verify the HTTP request - mock_http_client._make_request.assert_called_once() - args, kwargs = mock_http_client._make_request.call_args - assert kwargs["method"] == "GET" - assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( - "test-statement-123" - ) - - def test_get_execution_result_with_invalid_command_id( - self, sea_client, mock_cursor - ): - """Test getting execution result with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error + # Test get_execution_result with invalid ID with pytest.raises(ValueError) as excinfo: - sea_client.get_execution_result(command_id, mock_cursor) - + sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) - def test_max_download_threads_property(self, mock_http_client): - """Test the max_download_threads property.""" - # Test with default value - client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - ) - assert client.max_download_threads == 10 - - # Test with custom value - client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=5, - ) - assert client.max_download_threads == 5 - - def test_get_default_session_configuration_value(self): - """Test the get_default_session_configuration_value static method.""" - # Test with supported configuration parameter + def test_utility_methods(self, sea_client): + """Test utility methods.""" + # Test get_default_session_configuration_value value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE") assert value == "true" @@ -658,16 +539,13 @@ def test_get_default_session_configuration_value(self): value = 
SeaDatabricksClient.get_default_session_configuration_value("ansi_mode") assert value == "true" - def test_get_allowed_session_configurations(self): - """Test the get_allowed_session_configurations static method.""" + # Test get_allowed_session_configurations configs = SeaDatabricksClient.get_allowed_session_configurations() assert isinstance(configs, list) assert len(configs) > 0 assert "ANSI_MODE" in configs - def test_extract_description_from_manifest(self, sea_client): - """Test the _extract_description_from_manifest method.""" - # Test with valid manifest containing columns + # Test _extract_description_from_manifest manifest_obj = MagicMock() manifest_obj.schema = { "columns": [ @@ -689,15 +567,11 @@ def test_extract_description_from_manifest(self, sea_client): description = sea_client._extract_description_from_manifest(manifest_obj) assert description is not None assert len(description) == 2 - - # Check first column assert description[0][0] == "col1" # name assert description[0][1] == "STRING" # type_code assert description[0][4] == 10 # precision assert description[0][5] == 2 # scale assert description[0][6] is True # null_ok - - # Check second column assert description[1][0] == "col2" # name assert description[1][1] == "INT" # type_code assert description[1][6] is False # null_ok @@ -705,85 +579,37 @@ def test_extract_description_from_manifest(self, sea_client): # Test with manifest containing non-dict column manifest_obj.schema = {"columns": ["not_a_dict"]} description = sea_client._extract_description_from_manifest(manifest_obj) - assert ( - description is None - ) # Method returns None when no valid columns are found + assert description is None # Test with manifest without columns manifest_obj.schema = {} description = sea_client._extract_description_from_manifest(manifest_obj) assert description is None - def test_cancel_command_with_invalid_command_id(self, sea_client): - """Test canceling a command with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.cancel_command(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - - def test_close_command_with_invalid_command_id(self, sea_client): - """Test closing a command with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.close_command(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - - def test_get_query_state_with_invalid_command_id(self, sea_client): - """Test getting query state with an invalid command ID.""" - # Create a Thrift command ID (not SEA) - mock_thrift_operation_handle = MagicMock() - mock_thrift_operation_handle.operationId.guid = b"guid" - mock_thrift_operation_handle.operationId.secret = b"secret" - command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - 
sea_client.get_query_state(command_id) - - assert "Not a valid SEA command ID" in str(excinfo.value) - def test_unimplemented_metadata_methods( self, sea_client, sea_session_id, mock_cursor ): """Test that metadata methods raise NotImplementedError.""" # Test get_catalogs - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - assert "get_catalogs is not implemented for SEA backend" in str(excinfo.value) # Test get_schemas - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_schemas( sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" ) - assert "get_schemas is not implemented for SEA backend" in str(excinfo.value) # Test get_tables - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - assert "get_tables is not implemented for SEA backend" in str(excinfo.value) # Test get_tables with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_tables( sea_session_id, 100, @@ -794,15 +620,13 @@ def test_unimplemented_metadata_methods( table_name="table", table_types=["TABLE", "VIEW"], ) - assert "get_tables is not implemented for SEA backend" in str(excinfo.value) # Test get_columns - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - assert "get_columns is not implemented for SEA backend" in str(excinfo.value) # Test get_columns with optional parameters - with pytest.raises(NotImplementedError) as excinfo: + with pytest.raises(NotImplementedError): sea_client.get_columns( sea_session_id, 100, @@ -813,29 +637,3 @@ def test_unimplemented_metadata_methods( table_name="table", column_name="column", ) - assert "get_columns is not implemented for SEA backend" in str(excinfo.value) - - def test_execute_command_with_invalid_session_id(self, sea_client, mock_cursor): - """Test executing a command with an invalid session ID type.""" - # Create a Thrift session ID (not SEA) - mock_thrift_handle = MagicMock() - mock_thrift_handle.sessionId.guid = b"guid" - mock_thrift_handle.sessionId.secret = b"secret" - session_id = SessionId.from_thrift_handle(mock_thrift_handle) - - # Call the method and expect an error - with pytest.raises(ValueError) as excinfo: - sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, - max_rows=100, - max_bytes=1000, - lz4_compression=False, - cursor=mock_cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - ) - - assert "Not a valid SEA session ID" in str(excinfo.value) From ed446a0fe240d27626fa70657005f7f8ce065766 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:37:24 +0000 Subject: [PATCH 109/204] test getting the list of allowed configurations Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/test_sea_backend.py 
b/tests/unit/test_sea_backend.py index 4b1ec55a3..1d16763be 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -545,6 +545,20 @@ def test_utility_methods(self, sea_client): assert len(configs) > 0 assert "ANSI_MODE" in configs + # Test getting the list of allowed configurations with specific keys + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", + } + assert set(allowed_configs) == expected_keys + # Test _extract_description_from_manifest manifest_obj = MagicMock() manifest_obj.schema = { From 38e4b5c25517146acb90ae962a02cbb6a5c3b98e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:38:50 +0000 Subject: [PATCH 110/204] reduce diff Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1dde8e4dc..cf10c904a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -604,8 +604,8 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - """Get available catalogs by executing 'SHOW CATALOGS'.""" - raise NotImplementedError("get_catalogs is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") def get_schemas( self, @@ -616,8 +616,8 @@ def get_schemas( catalog_name: Optional[str] = None, schema_name: Optional[str] = None, ) -> "ResultSet": - """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" - raise NotImplementedError("get_schemas is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_schemas is not yet implemented for SEA backend") def get_tables( self, @@ -630,8 +630,8 @@ def get_tables( table_name: Optional[str] = None, table_types: Optional[List[str]] = None, ) -> "ResultSet": - """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_tables is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_tables is not yet implemented for SEA backend") def get_columns( self, @@ -644,5 +644,5 @@ def get_columns( table_name: Optional[str] = None, column_name: Optional[str] = None, ) -> "ResultSet": - """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" - raise NotImplementedError("get_columns is not implemented for SEA backend") + """Not implemented yet.""" + raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 94879c017ce2db6e289c46c47b51a7296c0db678 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 12 Jun 2025 12:39:28 +0000 Subject: [PATCH 111/204] reduce diff Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index cf10c904a..e892e10e7 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -603,7 +603,7 @@ def get_catalogs( max_rows: 
int, max_bytes: int, cursor: "Cursor", - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") @@ -615,7 +615,7 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_schemas is not yet implemented for SEA backend") @@ -629,7 +629,7 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_tables is not yet implemented for SEA backend") @@ -643,6 +643,6 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": + ): """Not implemented yet.""" raise NotImplementedError("get_columns is not yet implemented for SEA backend") From 18099560157074870d83f1a43146c1687962a92d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 13 Jun 2025 03:38:43 +0000 Subject: [PATCH 112/204] house constants in enums for readability and immutability Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 20 ++++++++++--- .../sql/backend/sea/utils/constants.py | 29 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index e892e10e7..4602db3b7 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -5,6 +5,10 @@ from databricks.sql.backend.sea.utils.constants import ( ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, + ResultFormat, + ResultDisposition, + ResultCompression, + WaitTimeout, ) if TYPE_CHECKING: @@ -405,9 +409,17 @@ def execute_command( ) ) - format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" - disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" - result_compression = "LZ4_FRAME" if lz4_compression else None + format = ( + ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY + ).value + disposition = ( + ResultDisposition.EXTERNAL_LINKS + if use_cloud_fetch + else ResultDisposition.INLINE + ).value + result_compression = ( + ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE + ).value request = ExecuteStatementRequest( warehouse_id=self.warehouse_id, @@ -415,7 +427,7 @@ def execute_command( statement=operation, disposition=disposition, format=format, - wait_timeout="0s" if async_op else "10s", + wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, on_wait_timeout="CONTINUE", row_limit=max_rows, parameters=sea_parameters if sea_parameters else None, diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 9160ef6ad..cd5cc657d 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -3,6 +3,7 @@ """ from typing import Dict +from enum import Enum # from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { @@ -15,3 +16,31 @@ "TIMEZONE": "UTC", "USE_CACHED_RESULT": "true", } + + +class ResultFormat(Enum): + """Enum for result format values.""" + + ARROW_STREAM = "ARROW_STREAM" + JSON_ARRAY = "JSON_ARRAY" + + +class ResultDisposition(Enum): + """Enum for result disposition 
values.""" + + EXTERNAL_LINKS = "EXTERNAL_LINKS" + INLINE = "INLINE" + + +class ResultCompression(Enum): + """Enum for result compression values.""" + + LZ4_FRAME = "LZ4_FRAME" + NONE = None + + +class WaitTimeout(Enum): + """Enum for wait timeout values.""" + + ASYNC = "0s" + SYNC = "10s" From da5260cd82ffcdd31ed6393d0d0101c41fc7fcc7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 13 Jun 2025 03:39:16 +0000 Subject: [PATCH 113/204] add note on hybrid disposition Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/constants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index cd5cc657d..7481a90db 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -28,6 +28,7 @@ class ResultFormat(Enum): class ResultDisposition(Enum): """Enum for result disposition values.""" + # TODO: add support for hybrid disposition EXTERNAL_LINKS = "EXTERNAL_LINKS" INLINE = "INLINE" From 0385ffb03a3684d5a00f74eed32610cacbc34331 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 06:31:29 +0000 Subject: [PATCH 114/204] remove redundant note on arrow_schema_bytes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 4602db3b7..b829f0644 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -354,7 +354,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, - arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW + arrow_schema_bytes=None, result_format=manifest_obj.format, ) From 23963fc931c809e9e455f966a5d8c4906d49a169 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 10:15:59 +0000 Subject: [PATCH 115/204] align SeaResultSet with ext-links-sea Signed-off-by: varun-edachali-dbx --- .../tests/test_sea_async_query.py | 15 +- .../experimental/tests/test_sea_sync_query.py | 13 +- src/databricks/sql/result_set.py | 344 ++++++++---------- tests/unit/test_sea_result_set.py | 308 +++------------- 4 files changed, 229 insertions(+), 451 deletions(-) diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 35135b64a..cfcbe307f 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -69,8 +69,12 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - rows = cursor.fetchall() - logger.info(f"Retrieved rows: {rows}") + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + logger.info( "Successfully retrieved asynchronous query results with cloud fetch enabled" ) @@ -150,8 +154,11 @@ def test_sea_async_query_without_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - rows = cursor.fetchall() - logger.info(f"Retrieved rows: {rows}") + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + 
results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + logger.info( "Successfully retrieved asynchronous query results with cloud fetch disabled" ) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 0f12445d1..a60410ba4 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -55,8 +55,10 @@ def test_sea_sync_query_with_cloud_fetch(): cursor.execute( "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - rows = cursor.fetchall() - logger.info(f"Retrieved rows: {rows}") + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") # Close resources cursor.close() @@ -121,10 +123,11 @@ def test_sea_sync_query_without_cloud_fetch(): cursor.execute( "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" ) - logger.info("Query executed successfully with cloud fetch disabled") - rows = cursor.fetchall() - logger.info(f"Retrieved rows: {rows}") + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") # Close resources cursor.close() diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index bd5897fb7..d100e3c72 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -92,6 +92,44 @@ def __iter__(self): else: break + def _convert_arrow_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + + if self.connection.disable_pandas is True: + return [ + ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) + ] + + # Need to use nullable types, as otherwise type can change when there are missing values. 
+ # See https://arrow.apache.org/docs/python/pandas.html#nullable-types + # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + # Need to rename columns, as the to_pandas function cannot handle duplicate column names + table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) + df = table_renamed.to_pandas( + types_mapper=dtype_mapping.get, + date_as_object=True, + timestamp_as_object=True, + ) + + res = df.to_numpy(na_value=None, dtype="object") + return [ResultRow(*v) for v in res] + @property def rownumber(self): return self._next_row_index @@ -101,12 +139,6 @@ def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" return self._is_staging_operation - # Define abstract methods that concrete implementations must implement - @abstractmethod - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - pass - @abstractmethod def fetchone(self) -> Optional[Row]: """Fetch the next row of a query result set.""" @@ -251,44 +283,6 @@ def _convert_columnar_table(self, table): return result - def _convert_arrow_table(self, table): - column_names = [c[0] for c in self.description] - ResultRow = Row(*column_names) - - if self.connection.disable_pandas is True: - return [ - ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns()) - ] - - # Need to use nullable types, as otherwise type can change when there are missing values. 
- # See https://arrow.apache.org/docs/python/pandas.html#nullable-types - # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { - pyarrow.int8(): pandas.Int8Dtype(), - pyarrow.int16(): pandas.Int16Dtype(), - pyarrow.int32(): pandas.Int32Dtype(), - pyarrow.int64(): pandas.Int64Dtype(), - pyarrow.uint8(): pandas.UInt8Dtype(), - pyarrow.uint16(): pandas.UInt16Dtype(), - pyarrow.uint32(): pandas.UInt32Dtype(), - pyarrow.uint64(): pandas.UInt64Dtype(), - pyarrow.bool_(): pandas.BooleanDtype(), - pyarrow.float32(): pandas.Float32Dtype(), - pyarrow.float64(): pandas.Float64Dtype(), - pyarrow.string(): pandas.StringDtype(), - } - - # Need to rename columns, as the to_pandas function cannot handle duplicate column names - table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) - - res = df.to_numpy(na_value=None, dtype="object") - return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": """ Function to merge / combining the columnar results into a single result @@ -458,8 +452,8 @@ def __init__( sea_client: "SeaDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, - result_data: Optional[ResultData] = None, - manifest: Optional[ResultManifest] = None, + result_data: Optional["ResultData"] = None, + manifest: Optional["ResultManifest"] = None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -474,18 +468,20 @@ def __init__( manifest: Manifest from SEA response (optional) """ + results_queue = None if result_data: - queue = SeaResultSetQueueFactory.build_queue( - sea_result_data=result_data, - manifest=manifest, - statement_id=execute_response.command_id.to_sea_statement_id(), + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + str(execute_response.command_id.to_sea_statement_id()), description=execute_response.description, - schema_bytes=execute_response.arrow_schema_bytes, + max_download_threads=sea_client.max_download_threads, + ssl_options=sea_client.ssl_options, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, ) - else: - logger.warning("No result data provided for SEA result set") - queue = JsonQueue([]) + # Call parent constructor with common attributes super().__init__( connection=connection, backend=sea_client, @@ -494,20 +490,20 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - results_queue=queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", ) - def _convert_to_row_objects(self, rows): + # Initialize queue for result data if not provided + self.results = results_queue or JsonQueue([]) + + def _convert_json_rows(self, rows): """ Convert raw data rows to Row objects with named columns based on description. 
- Args: rows: List of raw data rows - Returns: List of Row objects with named columns """ @@ -518,170 +514,140 @@ def _convert_to_row_objects(self, rows): ResultRow = Row(*column_names) return [ResultRow(*row) for row in rows] - def _fill_results_buffer(self): - """Fill the results buffer from the backend.""" - return None - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ - rows = self.results.next_n_rows(1) - if not rows: - return None + Fetch the next set of rows as an Arrow table. - # Convert to Row object - converted_rows = self._convert_to_row_objects(rows) - return converted_rows[0] if converted_rows else None + Args: + size: Number of rows to fetch - def fetchmany(self, size: Optional[int] = None) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. + Returns: + PyArrow Table containing the fetched rows - An empty sequence is returned when no more rows are available. + Raises: + ImportError: If PyArrow is not installed + ValueError: If size is negative """ - if size is None: - size = self.arraysize - if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - rows = self.results.next_n_rows(size) - self._next_row_index += len(rows) + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while n_remaining_rows > 0: + partial_results = self.results.next_n_rows(n_remaining_rows) + results = pyarrow.concat_tables([results, partial_results]) + n_remaining_rows = n_remaining_rows - partial_results.num_rows + self._next_row_index += partial_results.num_rows - # Convert to Row objects - return self._convert_to_row_objects(rows) + return results - def fetchall(self) -> List[Row]: + def fetchall_arrow(self) -> "pyarrow.Table": """ - Fetch all (remaining) rows of a query result, returning them as a list of rows. + Fetch all remaining rows as an Arrow table. + + Returns: + PyArrow Table containing all remaining rows + + Raises: + ImportError: If PyArrow is not installed """ + results = self.results.remaining_rows() + self._next_row_index += results.num_rows - rows = self.results.remaining_rows() - self._next_row_index += len(rows) + # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table + # Valid only for metadata commands result set + if isinstance(results, ColumnTable) and pyarrow: + data = { + name: col + for name, col in zip(results.column_names, results.column_table) + } + return pyarrow.Table.from_pydict(data) - # Convert to Row objects - return self._convert_to_row_objects(rows) + return results - def _create_empty_arrow_table(self) -> Any: + def fetchmany_json(self, size: int): """ - Create an empty PyArrow table with the schema from the result set. + Fetch the next set of rows as a columnar table. + + Args: + size: Number of rows to fetch Returns: - An empty PyArrow table with the correct schema. 
- """ - import pyarrow + Columnar table containing the fetched rows - # Try to use schema bytes if available - if self._arrow_schema_bytes: - schema = pyarrow.ipc.read_schema( - pyarrow.BufferReader(self._arrow_schema_bytes) - ) - return pyarrow.Table.from_pydict( - {name: [] for name in schema.names}, schema=schema - ) + Raises: + ValueError: If size is negative + """ + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - # Fall back to creating schema from description - if self.description: - # Map SQL types to PyArrow types - type_map = { - "boolean": pyarrow.bool_(), - "tinyint": pyarrow.int8(), - "smallint": pyarrow.int16(), - "int": pyarrow.int32(), - "bigint": pyarrow.int64(), - "float": pyarrow.float32(), - "double": pyarrow.float64(), - "string": pyarrow.string(), - "binary": pyarrow.binary(), - "timestamp": pyarrow.timestamp("us"), - "date": pyarrow.date32(), - "decimal": pyarrow.decimal128(38, 18), # Default precision and scale - } + results = self.results.next_n_rows(size) + n_remaining_rows = size - len(results) + self._next_row_index += len(results) - fields = [] - for col_desc in self.description: - col_name = col_desc[0] - col_type = col_desc[1].lower() if col_desc[1] else "string" - - # Handle decimal with precision and scale - if ( - col_type == "decimal" - and col_desc[4] is not None - and col_desc[5] is not None - ): - arrow_type = pyarrow.decimal128(col_desc[4], col_desc[5]) - else: - arrow_type = type_map.get(col_type, pyarrow.string()) - - fields.append(pyarrow.field(col_name, arrow_type)) - - schema = pyarrow.schema(fields) - return pyarrow.Table.from_pydict( - {name: [] for name in schema.names}, schema=schema - ) + while n_remaining_rows > 0: + partial_results = self.results.next_n_rows(n_remaining_rows) + results = results + partial_results + n_remaining_rows = n_remaining_rows - len(partial_results) + self._next_row_index += len(partial_results) - # If no schema information is available, return an empty table - return pyarrow.Table.from_pydict({}) + return results - def _convert_rows_to_arrow_table(self, rows: List[Row]) -> Any: + def fetchall_json(self): """ - Convert a list of Row objects to a PyArrow table. - - Args: - rows: List of Row objects to convert. + Fetch all remaining rows as a columnar table. Returns: - PyArrow table containing the data from the rows. + Columnar table containing all remaining rows """ - import pyarrow - - if not rows: - return self._create_empty_arrow_table() + results = self.results.remaining_rows() + self._next_row_index += len(results) - # Extract column names from description - if self.description: - column_names = [col[0] for col in self.description] - else: - # If no description, use the attribute names from the first row - column_names = rows[0]._fields + return results - # Convert rows to columns - columns: dict[str, list] = {name: [] for name in column_names} + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. 
- for row in rows: - for i, name in enumerate(column_names): - if hasattr(row, "_asdict"): # If it's a Row object - columns[name].append(row[i]) - else: # If it's a raw list - columns[name].append(row[i]) + Returns: + A single Row object or None if no more rows are available + """ + if isinstance(self.results, JsonQueue): + res = self._convert_json_rows(self.fetchmany_json(1)) + else: + raise NotImplementedError("fetchone only supported for JSON data") - # Create PyArrow table - return pyarrow.Table.from_pydict(columns) + return res[0] if res else None - def fetchmany_arrow(self, size: int) -> Any: - """Fetch the next set of rows as an Arrow table.""" - if not pyarrow: - raise ImportError("PyArrow is required for Arrow support") + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. - rows = self.fetchmany(size) - if not rows: - # Return empty Arrow table with schema - return self._create_empty_arrow_table() + Args: + size: Number of rows to fetch (defaults to arraysize if None) - # Convert rows to Arrow table - return self._convert_rows_to_arrow_table(rows) + Returns: + List of Row objects - def fetchall_arrow(self) -> Any: - """Fetch all remaining rows as an Arrow table.""" - if not pyarrow: - raise ImportError("PyArrow is required for Arrow support") + Raises: + ValueError: If size is negative + """ + if isinstance(self.results, JsonQueue): + return self._convert_json_rows(self.fetchmany_json(size)) + else: + raise NotImplementedError("fetchmany only supported for JSON data") - rows = self.fetchall() - if not rows: - # Return empty Arrow table with schema - return self._create_empty_arrow_table() + def fetchall(self) -> List[Row]: + """ + Fetch all remaining rows of a query result, returning them as a list of rows. 
- # Convert rows to Arrow table - return self._convert_rows_to_arrow_table(rows) + Returns: + List of Row objects containing all remaining rows + """ + if isinstance(self.results, JsonQueue): + return self._convert_json_rows(self.fetchall_json()) + else: + raise NotImplementedError("fetchall only supported for JSON data") diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 85ad60501..846e9e007 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -9,6 +9,7 @@ from unittest.mock import patch, MagicMock, Mock from databricks.sql.result_set import SeaResultSet +from databricks.sql.utils import JsonQueue from databricks.sql.backend.types import CommandId, CommandState, BackendType @@ -34,12 +35,12 @@ def execute_response(self): mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") mock_response.status = CommandState.SUCCEEDED mock_response.has_been_closed_server_side = False - mock_response.has_more_rows = False - mock_response.results_queue = None mock_response.description = [ ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = b"" return mock_response def test_init_with_execute_response( @@ -124,9 +125,9 @@ def test_close_when_connection_closed( assert result_set.status == CommandState.CLOSED @pytest.fixture - def mock_results_queue(self): - """Create a mock results queue.""" - mock_queue = Mock() + def mock_json_queue(self): + """Create a mock JsonQueue.""" + mock_queue = Mock(spec=JsonQueue) mock_queue.next_n_rows.return_value = [["value1", 123], ["value2", 456]] mock_queue.remaining_rows.return_value = [ ["value1", 123], @@ -135,85 +136,8 @@ def mock_results_queue(self): ] return mock_queue - def test_fill_results_buffer( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer returns None.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - assert result_set._fill_results_buffer() is None - - def test_convert_to_row_objects( - self, mock_connection, mock_sea_client, execute_response - ): - """Test converting raw data rows to Row objects.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test with empty description - result_set.description = None - rows = [["value1", 123], ["value2", 456]] - converted_rows = result_set._convert_to_row_objects(rows) - assert converted_rows == rows - - # Test with empty rows - result_set.description = [ - ("col1", "STRING", None, None, None, None, None), - ("col2", "INT", None, None, None, None, None), - ] - assert result_set._convert_to_row_objects([]) == [] - - # Test with description and rows - rows = [["value1", 123], ["value2", 456]] - converted_rows = result_set._convert_to_row_objects(rows) - assert len(converted_rows) == 2 - assert converted_rows[0].col1 == "value1" - assert converted_rows[0].col2 == 123 - assert converted_rows[1].col1 == "value2" - assert converted_rows[1].col2 == 456 - - def test_fetchone( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue - ): - """Test fetchone method.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - 
sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.results = mock_results_queue - result_set.description = [ - ("col1", "STRING", None, None, None, None, None), - ("col2", "INT", None, None, None, None, None), - ] - - # Mock the next_n_rows to return a single row - mock_results_queue.next_n_rows.return_value = [["value1", 123]] - - row = result_set.fetchone() - assert row is not None - assert row.col1 == "value1" - assert row.col2 == 123 - - # Test when no rows are available - mock_results_queue.next_n_rows.return_value = [] - assert result_set.fetchone() is None - def test_fetchmany( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue + self, mock_connection, mock_sea_client, execute_response, mock_json_queue ): """Test fetchmany method.""" result_set = SeaResultSet( @@ -223,7 +147,7 @@ def test_fetchmany( buffer_size_bytes=1000, arraysize=100, ) - result_set.results = mock_results_queue + result_set.results = mock_json_queue result_set.description = [ ("col1", "STRING", None, None, None, None, None), ("col2", "INT", None, None, None, None, None), @@ -239,9 +163,9 @@ def test_fetchmany( # Test with default size (arraysize) result_set.arraysize = 2 - mock_results_queue.next_n_rows.reset_mock() - rows = result_set.fetchmany() - mock_results_queue.next_n_rows.assert_called_with(2) + mock_json_queue.next_n_rows.reset_mock() + rows = result_set.fetchmany(result_set.arraysize) + mock_json_queue.next_n_rows.assert_called_with(2) # Test with negative size with pytest.raises( @@ -250,7 +174,7 @@ def test_fetchmany( result_set.fetchmany(-1) def test_fetchall( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue + self, mock_connection, mock_sea_client, execute_response, mock_json_queue ): """Test fetchall method.""" result_set = SeaResultSet( @@ -260,7 +184,7 @@ def test_fetchall( buffer_size_bytes=1000, arraysize=100, ) - result_set.results = mock_results_queue + result_set.results = mock_json_queue result_set.description = [ ("col1", "STRING", None, None, None, None, None), ("col2", "INT", None, None, None, None, None), @@ -278,16 +202,10 @@ def test_fetchall( # Verify _next_row_index is updated assert result_set._next_row_index == 3 - @pytest.mark.skipif( - pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, - reason="PyArrow is not installed", - ) - def test_create_empty_arrow_table( - self, mock_connection, mock_sea_client, execute_response, monkeypatch + def test_fetchmany_json( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue ): - """Test creating an empty Arrow table with schema.""" - import pyarrow - + """Test fetchmany_json method.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -295,47 +213,22 @@ def test_create_empty_arrow_table( buffer_size_bytes=1000, arraysize=100, ) + result_set.results = mock_json_queue - # Mock _arrow_schema_bytes to return a valid schema - schema = pyarrow.schema( - [ - pyarrow.field("col1", pyarrow.string()), - pyarrow.field("col2", pyarrow.int32()), - ] - ) - schema_bytes = schema.serialize().to_pybytes() - monkeypatch.setattr(result_set, "_arrow_schema_bytes", schema_bytes) - - # Test with schema bytes - empty_table = result_set._create_empty_arrow_table() - assert isinstance(empty_table, pyarrow.Table) - assert empty_table.num_rows == 0 - assert empty_table.num_columns == 2 - assert empty_table.schema.names == ["col1", "col2"] - - # Test without schema bytes but with 
description - monkeypatch.setattr(result_set, "_arrow_schema_bytes", b"") - result_set.description = [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ] + # Test with specific size + result_set.fetchmany_json(2) + mock_json_queue.next_n_rows.assert_called_with(2) - empty_table = result_set._create_empty_arrow_table() - assert isinstance(empty_table, pyarrow.Table) - assert empty_table.num_rows == 0 - assert empty_table.num_columns == 2 - assert empty_table.schema.names == ["col1", "col2"] - - @pytest.mark.skipif( - pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, - reason="PyArrow is not installed", - ) - def test_convert_rows_to_arrow_table( - self, mock_connection, mock_sea_client, execute_response - ): - """Test converting rows to Arrow table.""" - import pyarrow + # Test with negative size + with pytest.raises( + ValueError, match="size argument for fetchmany is -1 but must be >= 0" + ): + result_set.fetchmany_json(-1) + def test_fetchall_json( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue + ): + """Test fetchall_json method.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -343,34 +236,16 @@ def test_convert_rows_to_arrow_table( buffer_size_bytes=1000, arraysize=100, ) + result_set.results = mock_json_queue - result_set.description = [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ] - - rows = [["value1", 123], ["value2", 456], ["value3", 789]] - - arrow_table = result_set._convert_rows_to_arrow_table(rows) - assert isinstance(arrow_table, pyarrow.Table) - assert arrow_table.num_rows == 3 - assert arrow_table.num_columns == 2 - assert arrow_table.schema.names == ["col1", "col2"] + # Test fetchall_json + result_set.fetchall_json() + mock_json_queue.remaining_rows.assert_called_once() - # Check data - assert arrow_table.column(0).to_pylist() == ["value1", "value2", "value3"] - assert arrow_table.column(1).to_pylist() == [123, 456, 789] - - @pytest.mark.skipif( - pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, - reason="PyArrow is not installed", - ) - def test_fetchmany_arrow( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue + def test_convert_json_rows( + self, mock_connection, mock_sea_client, execute_response ): - """Test fetchmany_arrow method.""" - import pyarrow - + """Test _convert_json_rows method.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -378,103 +253,30 @@ def test_fetchmany_arrow( buffer_size_bytes=1000, arraysize=100, ) - result_set.results = mock_results_queue - result_set.description = [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ] - - # Test with data - arrow_table = result_set.fetchmany_arrow(2) - assert isinstance(arrow_table, pyarrow.Table) - assert arrow_table.num_rows == 2 - assert arrow_table.column(0).to_pylist() == ["value1", "value2"] - assert arrow_table.column(1).to_pylist() == [123, 456] - - # Test with no data - mock_results_queue.next_n_rows.return_value = [] - - # Mock _create_empty_arrow_table to return an empty table - result_set._create_empty_arrow_table = Mock() - empty_table = pyarrow.Table.from_pydict({"col1": [], "col2": []}) - result_set._create_empty_arrow_table.return_value = empty_table - - arrow_table = result_set.fetchmany_arrow(2) - assert isinstance(arrow_table, 
pyarrow.Table) - assert arrow_table.num_rows == 0 - result_set._create_empty_arrow_table.assert_called_once() - - @pytest.mark.skipif( - pytest.importorskip("pyarrow", reason="PyArrow is not installed") is None, - reason="PyArrow is not installed", - ) - def test_fetchall_arrow( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue - ): - """Test fetchall_arrow method.""" - import pyarrow - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.results = mock_results_queue + # Test with description and rows result_set.description = [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), ] + rows = [["value1", 123], ["value2", 456]] + converted_rows = result_set._convert_json_rows(rows) - # Test with data - arrow_table = result_set.fetchall_arrow() - assert isinstance(arrow_table, pyarrow.Table) - assert arrow_table.num_rows == 3 - assert arrow_table.column(0).to_pylist() == ["value1", "value2", "value3"] - assert arrow_table.column(1).to_pylist() == [123, 456, 789] - - # Test with no data - mock_results_queue.remaining_rows.return_value = [] - - # Mock _create_empty_arrow_table to return an empty table - result_set._create_empty_arrow_table = Mock() - empty_table = pyarrow.Table.from_pydict({"col1": [], "col2": []}) - result_set._create_empty_arrow_table.return_value = empty_table + assert len(converted_rows) == 2 + assert converted_rows[0].col1 == "value1" + assert converted_rows[0].col2 == 123 + assert converted_rows[1].col1 == "value2" + assert converted_rows[1].col2 == 456 - arrow_table = result_set.fetchall_arrow() - assert isinstance(arrow_table, pyarrow.Table) - assert arrow_table.num_rows == 0 - result_set._create_empty_arrow_table.assert_called_once() + # Test with no description + result_set.description = None + converted_rows = result_set._convert_json_rows(rows) + assert converted_rows == rows - def test_iteration_protocol( - self, mock_connection, mock_sea_client, execute_response, mock_results_queue - ): - """Test iteration protocol using fetchone.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.results = mock_results_queue + # Test with empty rows result_set.description = [ - ("col1", "string", None, None, None, None, None), - ("col2", "int", None, None, None, None, None), - ] - - # Set up mock to return different values on each call - mock_results_queue.next_n_rows.side_effect = [ - [["value1", 123]], - [["value2", 456]], - [], # End of data + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), ] - - # Test iteration - rows = list(result_set) - assert len(rows) == 2 - assert rows[0].col1 == "value1" - assert rows[0].col2 == 123 - assert rows[1].col1 == "value2" - assert rows[1].col2 == 456 + converted_rows = result_set._convert_json_rows([]) + assert converted_rows == [] From dd43715207a3e040aa5cf0bf0858c30bed82b91e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 10:26:14 +0000 Subject: [PATCH 116/204] remove redundant methods Signed-off-by: varun-edachali-dbx --- poetry.lock | 265 ++++++++++++++++++++++++++++-- pyproject.toml | 3 + src/databricks/sql/result_set.py | 
78 +-------- tests/unit/test_sea_result_set.py | 167 ++++++++++++++++++- 4 files changed, 425 insertions(+), 88 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1bc396c9d..12d984f22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "astroid" @@ -6,6 +6,7 @@ version = "3.2.4" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, @@ -20,6 +21,7 @@ version = "22.12.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, @@ -55,6 +57,7 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -66,6 +69,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -167,6 +171,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -181,17 +186,193 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.6.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version < \"3.10\"" +files = [ + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = 
"coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] + +[[package]] +name = "coverage" +version = "7.9.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "coverage-7.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cc94d7c5e8423920787c33d811c0be67b7be83c705f001f7180c7b186dcf10ca"}, + {file = "coverage-7.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16aa0830d0c08a2c40c264cef801db8bc4fc0e1892782e45bcacbd5889270509"}, + {file = "coverage-7.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf95981b126f23db63e9dbe4cf65bd71f9a6305696fa5e2262693bc4e2183f5b"}, + {file = "coverage-7.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f05031cf21699785cd47cb7485f67df619e7bcdae38e0fde40d23d3d0210d3c3"}, + {file = "coverage-7.9.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4fbcab8764dc072cb651a4bcda4d11fb5658a1d8d68842a862a6610bd8cfa3"}, + {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0f16649a7330ec307942ed27d06ee7e7a38417144620bb3d6e9a18ded8a2d3e5"}, + {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cea0a27a89e6432705fffc178064503508e3c0184b4f061700e771a09de58187"}, + {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e980b53a959fa53b6f05343afbd1e6f44a23ed6c23c4b4c56c6662bbb40c82ce"}, + {file = "coverage-7.9.1-cp310-cp310-win32.whl", hash = "sha256:70760b4c5560be6ca70d11f8988ee6542b003f982b32f83d5ac0b72476607b70"}, + {file = "coverage-7.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a66e8f628b71f78c0e0342003d53b53101ba4e00ea8dabb799d9dba0abbbcebe"}, + {file = "coverage-7.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:95c765060e65c692da2d2f51a9499c5e9f5cf5453aeaf1420e3fc847cc060582"}, + {file = "coverage-7.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba383dc6afd5ec5b7a0d0c23d38895db0e15bcba7fb0fa8901f245267ac30d86"}, + {file = "coverage-7.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37ae0383f13cbdcf1e5e7014489b0d71cc0106458878ccde52e8a12ced4298ed"}, + {file = "coverage-7.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69aa417a030bf11ec46149636314c24c8d60fadb12fc0ee8f10fda0d918c879d"}, + {file = "coverage-7.9.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a4be2a28656afe279b34d4f91c3e26eccf2f85500d4a4ff0b1f8b54bf807338"}, + {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:382e7ddd5289f140259b610e5f5c58f713d025cb2f66d0eb17e68d0a94278875"}, + {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e5532482344186c543c37bfad0ee6069e8ae4fc38d073b8bc836fc8f03c9e250"}, + {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a39d18b3f50cc121d0ce3838d32d58bd1d15dab89c910358ebefc3665712256c"}, + {file = "coverage-7.9.1-cp311-cp311-win32.whl", hash = "sha256:dd24bd8d77c98557880def750782df77ab2b6885a18483dc8588792247174b32"}, + {file = "coverage-7.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b55ad10a35a21b8015eabddc9ba31eb590f54adc9cd39bcf09ff5349fd52125"}, + {file = "coverage-7.9.1-cp311-cp311-win_arm64.whl", hash = "sha256:6ad935f0016be24c0e97fc8c40c465f9c4b85cbbe6eac48934c0dc4d2568321e"}, + {file = "coverage-7.9.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8de12b4b87c20de895f10567639c0797b621b22897b0af3ce4b4e204a743626"}, + {file = "coverage-7.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5add197315a054e92cee1b5f686a2bcba60c4c3e66ee3de77ace6c867bdee7cb"}, + {file = "coverage-7.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600a1d4106fe66f41e5d0136dfbc68fe7200a5cbe85610ddf094f8f22e1b0300"}, + {file = "coverage-7.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a876e4c3e5a2a1715a6608906aa5a2e0475b9c0f68343c2ada98110512ab1d8"}, + {file = "coverage-7.9.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81f34346dd63010453922c8e628a52ea2d2ccd73cb2487f7700ac531b247c8a5"}, + {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:888f8eee13f2377ce86d44f338968eedec3291876b0b8a7289247ba52cb984cd"}, + {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9969ef1e69b8c8e1e70d591f91bbc37fc9a3621e447525d1602801a24ceda898"}, + {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:60c458224331ee3f1a5b472773e4a085cc27a86a0b48205409d364272d67140d"}, + {file = "coverage-7.9.1-cp312-cp312-win32.whl", hash = "sha256:5f646a99a8c2b3ff4c6a6e081f78fad0dde275cd59f8f49dc4eab2e394332e74"}, + {file = "coverage-7.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:30f445f85c353090b83e552dcbbdad3ec84c7967e108c3ae54556ca69955563e"}, + {file = "coverage-7.9.1-cp312-cp312-win_arm64.whl", hash = "sha256:af41da5dca398d3474129c58cb2b106a5d93bbb196be0d307ac82311ca234342"}, + {file = "coverage-7.9.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:31324f18d5969feef7344a932c32428a2d1a3e50b15a6404e97cba1cc9b2c631"}, + {file = 
"coverage-7.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0c804506d624e8a20fb3108764c52e0eef664e29d21692afa375e0dd98dc384f"}, + {file = "coverage-7.9.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef64c27bc40189f36fcc50c3fb8f16ccda73b6a0b80d9bd6e6ce4cffcd810bbd"}, + {file = "coverage-7.9.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4fe2348cc6ec372e25adec0219ee2334a68d2f5222e0cba9c0d613394e12d86"}, + {file = "coverage-7.9.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34ed2186fe52fcc24d4561041979a0dec69adae7bce2ae8d1c49eace13e55c43"}, + {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:25308bd3d00d5eedd5ae7d4357161f4df743e3c0240fa773ee1b0f75e6c7c0f1"}, + {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:73e9439310f65d55a5a1e0564b48e34f5369bee943d72c88378f2d576f5a5751"}, + {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ab6be0859141b53aa89412a82454b482c81cf750de4f29223d52268a86de67"}, + {file = "coverage-7.9.1-cp313-cp313-win32.whl", hash = "sha256:64bdd969456e2d02a8b08aa047a92d269c7ac1f47e0c977675d550c9a0863643"}, + {file = "coverage-7.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:be9e3f68ca9edb897c2184ad0eee815c635565dbe7a0e7e814dc1f7cbab92c0a"}, + {file = "coverage-7.9.1-cp313-cp313-win_arm64.whl", hash = "sha256:1c503289ffef1d5105d91bbb4d62cbe4b14bec4d13ca225f9c73cde9bb46207d"}, + {file = "coverage-7.9.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0b3496922cb5f4215bf5caaef4cf12364a26b0be82e9ed6d050f3352cf2d7ef0"}, + {file = "coverage-7.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9565c3ab1c93310569ec0d86b017f128f027cab0b622b7af288696d7ed43a16d"}, + {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2241ad5dbf79ae1d9c08fe52b36d03ca122fb9ac6bca0f34439e99f8327ac89f"}, + {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bb5838701ca68b10ebc0937dbd0eb81974bac54447c55cd58dea5bca8451029"}, + {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30a25f814591a8c0c5372c11ac8967f669b97444c47fd794926e175c4047ece"}, + {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2d04b16a6062516df97969f1ae7efd0de9c31eb6ebdceaa0d213b21c0ca1a683"}, + {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7931b9e249edefb07cd6ae10c702788546341d5fe44db5b6108a25da4dca513f"}, + {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:52e92b01041151bf607ee858e5a56c62d4b70f4dac85b8c8cb7fb8a351ab2c10"}, + {file = "coverage-7.9.1-cp313-cp313t-win32.whl", hash = "sha256:684e2110ed84fd1ca5f40e89aa44adf1729dc85444004111aa01866507adf363"}, + {file = "coverage-7.9.1-cp313-cp313t-win_amd64.whl", hash = "sha256:437c576979e4db840539674e68c84b3cda82bc824dd138d56bead1435f1cb5d7"}, + {file = "coverage-7.9.1-cp313-cp313t-win_arm64.whl", hash = "sha256:18a0912944d70aaf5f399e350445738a1a20b50fbea788f640751c2ed9208b6c"}, + {file = "coverage-7.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f424507f57878e424d9a95dc4ead3fbdd72fd201e404e861e465f28ea469951"}, + {file = "coverage-7.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:535fde4001b2783ac80865d90e7cc7798b6b126f4cd8a8c54acfe76804e54e58"}, + {file = "coverage-7.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02532fd3290bb8fa6bec876520842428e2a6ed6c27014eca81b031c2d30e3f71"}, + {file = "coverage-7.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56f5eb308b17bca3bbff810f55ee26d51926d9f89ba92707ee41d3c061257e55"}, + {file = "coverage-7.9.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfa447506c1a52271f1b0de3f42ea0fa14676052549095e378d5bff1c505ff7b"}, + {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9ca8e220006966b4a7b68e8984a6aee645a0384b0769e829ba60281fe61ec4f7"}, + {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:49f1d0788ba5b7ba65933f3a18864117c6506619f5ca80326b478f72acf3f385"}, + {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:68cd53aec6f45b8e4724c0950ce86eacb775c6be01ce6e3669fe4f3a21e768ed"}, + {file = "coverage-7.9.1-cp39-cp39-win32.whl", hash = "sha256:95335095b6c7b1cc14c3f3f17d5452ce677e8490d101698562b2ffcacc304c8d"}, + {file = "coverage-7.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:e1b5191d1648acc439b24721caab2fd0c86679d8549ed2c84d5a7ec1bedcc244"}, + {file = "coverage-7.9.1-pp39.pp310.pp311-none-any.whl", hash = "sha256:db0f04118d1db74db6c9e1cb1898532c7dcc220f1d2718f058601f7c3f499514"}, + {file = "coverage-7.9.1-py3-none-any.whl", hash = "sha256:66b974b145aa189516b6bf2d8423e888b742517d37872f6ee4c5be0073bd9a3c"}, + {file = "coverage-7.9.1.tar.gz", hash = "sha256:6cf43c78c4282708a28e466316935ec7489a9c487518a77fa68f716c67909cec"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] + [[package]] name = "dill" version = "0.3.9" description = "serialize all of Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, @@ -207,6 +388,7 @@ version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, @@ -218,6 +400,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -232,6 +416,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = 
"idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -246,6 +431,7 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -257,6 +443,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -271,6 +458,7 @@ version = "4.3.3" description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -321,6 +509,7 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -332,6 +521,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -391,6 +581,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -402,6 +593,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] +markers = "python_version < \"3.10\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -439,6 +632,8 @@ version = "2.2.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" +groups = ["main", "dev"] +markers = "python_version >= \"3.10\"" files = [ {file = "numpy-2.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8146f3550d627252269ac42ae660281d673eb6f8b32f113538e0cc2a9aed42b9"}, {file = "numpy-2.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e642d86b8f956098b564a45e6f6ce68a22c2c97a04f5acd3f221f57b8cb850ae"}, @@ -503,6 +698,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -519,6 +715,7 @@ version = "3.1.5" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, @@ -533,6 +730,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -544,6 +742,8 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -573,11 +773,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, -] +numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} python-dateutil = ">=2.8.2" pytz = 
">=2020.1" tzdata = ">=2022.1" @@ -611,6 +807,8 @@ version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, @@ -657,7 +855,11 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, +] python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.7" @@ -693,6 +895,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -704,6 +907,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -720,6 +924,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -735,6 +940,8 @@ version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, @@ -786,6 +993,8 @@ version = "19.0.1" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -840,6 +1049,7 @@ version = "3.2.7" description = "python code static checker" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "pylint-3.2.7-py3-none-any.whl", hash = 
"sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, @@ -851,7 +1061,7 @@ colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.3.6", markers = "python_version == \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -870,6 +1080,7 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -886,12 +1097,32 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + [[package]] name = "pytest-dotenv" version = "0.5.2" description = "A py.test plugin that parses environment files before running tests" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, @@ -907,6 +1138,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -921,6 +1153,7 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -935,6 +1168,7 @@ version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash 
= "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, @@ -946,6 +1180,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -967,6 +1202,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -978,6 +1214,7 @@ version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] @@ -996,6 +1233,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1037,6 +1276,7 @@ version = "0.13.2" description = "Style preserving TOML library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, @@ -1048,6 +1288,7 @@ version = "4.13.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5"}, {file = "typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b"}, @@ -1059,6 +1300,7 @@ version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, @@ -1070,13 +1312,14 @@ version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1085,6 +1328,6 @@ zstd = ["zstandard (>=0.18.0)"] pyarrow = ["pyarrow", "pyarrow"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0bd6a6a019693a69a3da5ae312cea625ea73dfc5832b1e4051c7c7d1e76553d8" +content-hash = "7565c2cfcd646d789c9da8fd7b9f33cc1d592c434d3fdf1cf6063cbb0362dc10" diff --git a/pyproject.toml b/pyproject.toml index 7b95a5097..3def9abdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ numpy = [ "Homepage" = "https://github.com/databricks/databricks-sql-python" "Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues" +[tool.poetry.group.dev.dependencies] +pytest-cov = "4.1.0" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index d100e3c72..12ba1ee20 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -154,16 +154,6 @@ def fetchall(self) -> List[Row]: """Fetch all remaining rows of a query result.""" pass - @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """Fetch the next set of rows as an Arrow table.""" - pass - - @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": - """Fetch all remaining rows as an Arrow table.""" - pass - def close(self) -> None: """ Close the result set. @@ -499,7 +489,7 @@ def __init__( # Initialize queue for result data if not provided self.results = results_queue or JsonQueue([]) - def _convert_json_rows(self, rows): + def _convert_json_table(self, rows): """ Convert raw data rows to Row objects with named columns based on description. Args: @@ -514,59 +504,6 @@ def _convert_json_rows(self, rows): ResultRow = Row(*column_names) return [ResultRow(*row) for row in rows] - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows as an Arrow table. - - Args: - size: Number of rows to fetch - - Returns: - PyArrow Table containing the fetched rows - - Raises: - ImportError: If PyArrow is not installed - ValueError: If size is negative - """ - if size < 0: - raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - - results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows - - while n_remaining_rows > 0: - partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables([results, partial_results]) - n_remaining_rows = n_remaining_rows - partial_results.num_rows - self._next_row_index += partial_results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """ - Fetch all remaining rows as an Arrow table. 
- - Returns: - PyArrow Table containing all remaining rows - - Raises: - ImportError: If PyArrow is not installed - """ - results = self.results.remaining_rows() - self._next_row_index += results.num_rows - - # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table - # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: - data = { - name: col - for name, col in zip(results.column_names, results.column_table) - } - return pyarrow.Table.from_pydict(data) - - return results - def fetchmany_json(self, size: int): """ Fetch the next set of rows as a columnar table. @@ -584,15 +521,8 @@ def fetchmany_json(self, size: int): raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") results = self.results.next_n_rows(size) - n_remaining_rows = size - len(results) self._next_row_index += len(results) - while n_remaining_rows > 0: - partial_results = self.results.next_n_rows(n_remaining_rows) - results = results + partial_results - n_remaining_rows = n_remaining_rows - len(partial_results) - self._next_row_index += len(partial_results) - return results def fetchall_json(self): @@ -616,7 +546,7 @@ def fetchone(self) -> Optional[Row]: A single Row object or None if no more rows are available """ if isinstance(self.results, JsonQueue): - res = self._convert_json_rows(self.fetchmany_json(1)) + res = self._convert_json_table(self.fetchmany_json(1)) else: raise NotImplementedError("fetchone only supported for JSON data") @@ -636,7 +566,7 @@ def fetchmany(self, size: int) -> List[Row]: ValueError: If size is negative """ if isinstance(self.results, JsonQueue): - return self._convert_json_rows(self.fetchmany_json(size)) + return self._convert_json_table(self.fetchmany_json(size)) else: raise NotImplementedError("fetchmany only supported for JSON data") @@ -648,6 +578,6 @@ def fetchall(self) -> List[Row]: List of Row objects containing all remaining rows """ if isinstance(self.results, JsonQueue): - return self._convert_json_rows(self.fetchall_json()) + return self._convert_json_table(self.fetchall_json()) else: raise NotImplementedError("fetchall only supported for JSON data") diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 846e9e007..3fef0ebab 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -21,6 +21,7 @@ def mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True + connection.disable_pandas = False return connection @pytest.fixture @@ -260,7 +261,7 @@ def test_convert_json_rows( ("col2", "INT", None, None, None, None, None), ] rows = [["value1", 123], ["value2", 456]] - converted_rows = result_set._convert_json_rows(rows) + converted_rows = result_set._convert_json_table(rows) assert len(converted_rows) == 2 assert converted_rows[0].col1 == "value1" @@ -270,7 +271,7 @@ def test_convert_json_rows( # Test with no description result_set.description = None - converted_rows = result_set._convert_json_rows(rows) + converted_rows = result_set._convert_json_table(rows) assert converted_rows == rows # Test with empty rows @@ -278,5 +279,165 @@ def test_convert_json_rows( ("col1", "STRING", None, None, None, None, None), ("col2", "INT", None, None, None, None, None), ] - converted_rows = result_set._convert_json_rows([]) + converted_rows = result_set._convert_json_table([]) assert converted_rows == [] + + @pytest.fixture + def mock_arrow_queue(self): + """Create a mock queue that returns PyArrow 
tables.""" + mock_queue = Mock() + + # Mock PyArrow Table for next_n_rows + mock_table1 = Mock() + mock_table1.num_rows = 2 + mock_queue.next_n_rows.return_value = mock_table1 + + # Mock PyArrow Table for remaining_rows + mock_table2 = Mock() + mock_table2.num_rows = 3 + mock_queue.remaining_rows.return_value = mock_table2 + + return mock_queue + + @patch("pyarrow.concat_tables") + def test_fetchmany_arrow( + self, + mock_concat_tables, + mock_connection, + mock_sea_client, + execute_response, + mock_arrow_queue, + ): + """Test fetchmany_arrow method.""" + # Setup mock for pyarrow.concat_tables + mock_concat_result = Mock() + mock_concat_result.num_rows = 3 + mock_concat_tables.return_value = mock_concat_result + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_arrow_queue + + # Test with specific size + result = result_set.fetchmany_arrow(5) + + # Verify next_n_rows was called with the correct size + mock_arrow_queue.next_n_rows.assert_called_with(5) + + # Verify _next_row_index was updated + assert result_set._next_row_index == 2 + + # Test with negative size + with pytest.raises( + ValueError, match="size argument for fetchmany is -1 but must be >= 0" + ): + result_set.fetchmany_arrow(-1) + + def test_fetchall_arrow( + self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue + ): + """Test fetchall_arrow method.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_arrow_queue + + # Test fetchall_arrow + result = result_set.fetchall_arrow() + + # Verify remaining_rows was called + mock_arrow_queue.remaining_rows.assert_called_once() + + # Verify _next_row_index was updated + assert result_set._next_row_index == 3 + + def test_fetchone( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue + ): + """Test fetchone method.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_json_queue + result_set.description = [ + ("col1", "STRING", None, None, None, None, None), + ("col2", "INT", None, None, None, None, None), + ] + + # Mock fetchmany_json to return a single row + mock_json_queue.next_n_rows.return_value = [["value1", 123]] + + # Test fetchone + row = result_set.fetchone() + assert row is not None + assert row.col1 == "value1" + assert row.col2 == 123 + + # Test fetchone with no results + mock_json_queue.next_n_rows.return_value = [] + row = result_set.fetchone() + assert row is None + + # Test fetchone with non-JsonQueue + result_set.results = Mock() + result_set.results.__class__ = type("NotJsonQueue", (), {}) + + with pytest.raises( + NotImplementedError, match="fetchone only supported for JSON data" + ): + result_set.fetchone() + + def test_fetchmany_with_non_json_queue( + self, mock_connection, mock_sea_client, execute_response + ): + """Test fetchmany with a non-JsonQueue results object.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Set results to a non-JsonQueue object + result_set.results = Mock() + result_set.results.__class__ = type("NotJsonQueue", (), {}) + + 
with pytest.raises( + NotImplementedError, match="fetchmany only supported for JSON data" + ): + result_set.fetchmany(2) + + def test_fetchall_with_non_json_queue( + self, mock_connection, mock_sea_client, execute_response + ): + """Test fetchall with a non-JsonQueue results object.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Set results to a non-JsonQueue object + result_set.results = Mock() + result_set.results.__class__ = type("NotJsonQueue", (), {}) + + with pytest.raises( + NotImplementedError, match="fetchall only supported for JSON data" + ): + result_set.fetchall() From 34a7f66b9c2e6fd76ffffcbdb24ce7ee66c7c58c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 11:27:45 +0000 Subject: [PATCH 117/204] update unit tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 286 +++++++++++++++++++++++------- 1 file changed, 223 insertions(+), 63 deletions(-) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 3fef0ebab..d5e2b4c7b 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -7,10 +7,14 @@ import pytest from unittest.mock import patch, MagicMock, Mock +import logging -from databricks.sql.result_set import SeaResultSet -from databricks.sql.utils import JsonQueue +from databricks.sql.result_set import SeaResultSet, ResultSet +from databricks.sql.utils import JsonQueue, ResultSetQueue +from databricks.sql.types import Row +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.exc import RequestError, CursorAlreadyClosedError class TestSeaResultSet: @@ -299,67 +303,6 @@ def mock_arrow_queue(self): return mock_queue - @patch("pyarrow.concat_tables") - def test_fetchmany_arrow( - self, - mock_concat_tables, - mock_connection, - mock_sea_client, - execute_response, - mock_arrow_queue, - ): - """Test fetchmany_arrow method.""" - # Setup mock for pyarrow.concat_tables - mock_concat_result = Mock() - mock_concat_result.num_rows = 3 - mock_concat_tables.return_value = mock_concat_result - - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.results = mock_arrow_queue - - # Test with specific size - result = result_set.fetchmany_arrow(5) - - # Verify next_n_rows was called with the correct size - mock_arrow_queue.next_n_rows.assert_called_with(5) - - # Verify _next_row_index was updated - assert result_set._next_row_index == 2 - - # Test with negative size - with pytest.raises( - ValueError, match="size argument for fetchmany is -1 but must be >= 0" - ): - result_set.fetchmany_arrow(-1) - - def test_fetchall_arrow( - self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue - ): - """Test fetchall_arrow method.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - result_set.results = mock_arrow_queue - - # Test fetchall_arrow - result = result_set.fetchall_arrow() - - # Verify remaining_rows was called - mock_arrow_queue.remaining_rows.assert_called_once() - - # Verify _next_row_index was updated - assert result_set._next_row_index == 3 - def 
test_fetchone( self, mock_connection, mock_sea_client, execute_response, mock_json_queue ): @@ -441,3 +384,220 @@ def test_fetchall_with_non_json_queue( NotImplementedError, match="fetchall only supported for JSON data" ): result_set.fetchall() + + def test_iterator_protocol( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue + ): + """Test the iterator protocol (__iter__) implementation.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_json_queue + result_set.description = [ + ("test_value", "INT", None, None, None, None, None), + ] + + # Mock fetchone to return a sequence of values and then None + with patch.object(result_set, "fetchone") as mock_fetchone: + mock_fetchone.side_effect = [ + Row("test_value")(100), + Row("test_value")(200), + Row("test_value")(300), + None, + ] + + # Test iterating over the result set + rows = list(result_set) + assert len(rows) == 3 + assert rows[0].test_value == 100 + assert rows[1].test_value == 200 + assert rows[2].test_value == 300 + + def test_rownumber_property( + self, mock_connection, mock_sea_client, execute_response, mock_json_queue + ): + """Test the rownumber property.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.results = mock_json_queue + + # Initial row number should be 0 + assert result_set.rownumber == 0 + + # After fetching rows, row number should be updated + mock_json_queue.next_n_rows.return_value = [["value1"]] + result_set.fetchmany_json(2) + result_set._next_row_index = 2 + assert result_set.rownumber == 2 + + # After fetching more rows, row number should be incremented + mock_json_queue.next_n_rows.return_value = [["value3"]] + result_set.fetchmany_json(1) + result_set._next_row_index = 3 + assert result_set.rownumber == 3 + + def test_is_staging_operation_property(self, mock_connection, mock_sea_client): + """Test the is_staging_operation property.""" + # Create a response with staging operation set to True + staging_response = Mock() + staging_response.command_id = CommandId.from_sea_statement_id( + "test-staging-123" + ) + staging_response.status = CommandState.SUCCEEDED + staging_response.has_been_closed_server_side = False + staging_response.description = [] + staging_response.is_staging_operation = True + staging_response.lz4_compressed = False + staging_response.arrow_schema_bytes = b"" + + # Create a result set with staging operation + result_set = SeaResultSet( + connection=mock_connection, + execute_response=staging_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify the is_staging_operation property + assert result_set.is_staging_operation is True + + def test_init_with_result_data( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with result data.""" + # Create sample result data with a mock + result_data = Mock(spec=ResultData) + result_data.data = [["value1", 123], ["value2", 456]] + result_data.external_links = None + + manifest = Mock(spec=ResultManifest) + + # Mock the SeaResultSetQueueFactory.build_queue method + with patch( + "databricks.sql.result_set.SeaResultSetQueueFactory" + ) as factory_mock: + # Create a mock JsonQueue + mock_queue = Mock(spec=JsonQueue) + 
factory_mock.build_queue.return_value = mock_queue + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=result_data, + manifest=manifest, + ) + + # Verify the factory was called with the right parameters + factory_mock.build_queue.assert_called_once_with( + result_data, + manifest, + str(execute_response.command_id.to_sea_statement_id()), + description=execute_response.description, + max_download_threads=mock_sea_client.max_download_threads, + ssl_options=mock_sea_client.ssl_options, + sea_client=mock_sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Verify the results queue was set correctly + assert result_set.results == mock_queue + + def test_close_with_request_error( + self, mock_connection, mock_sea_client, execute_response + ): + """Test closing a result set when a RequestError is raised.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Create a patched version of the close method that doesn't check e.args[1] + with patch("databricks.sql.result_set.ResultSet.close") as mock_close: + # Call the close method + result_set.close() + + # Verify the parent's close method was called + mock_close.assert_called_once() + + def test_init_with_empty_result_data( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with empty result data.""" + # Create sample result data with a mock + result_data = Mock(spec=ResultData) + result_data.data = None + result_data.external_links = None + + manifest = Mock(spec=ResultManifest) + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=result_data, + manifest=manifest, + ) + + # Verify an empty JsonQueue was created + assert isinstance(result_set.results, JsonQueue) + assert result_set.results.data_array == [] + + def test_init_without_result_data( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet without result data.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify an empty JsonQueue was created + assert isinstance(result_set.results, JsonQueue) + assert result_set.results.data_array == [] + + def test_init_with_external_links( + self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with external links.""" + # Create sample result data with external links + result_data = Mock(spec=ResultData) + result_data.data = None + result_data.external_links = ["link1", "link2"] + + manifest = Mock(spec=ResultManifest) + + # This should raise NotImplementedError + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=result_data, + manifest=manifest, + ) From 715cc135f2c39329210194c1a3e9c454f1792601 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 16 Jun 2025 11:29:09 +0000 Subject: [PATCH 118/204] remove accidental venv changes 
Signed-off-by: varun-edachali-dbx --- poetry.lock | 265 ++----------------------------------------------- pyproject.toml | 3 - 2 files changed, 11 insertions(+), 257 deletions(-) diff --git a/poetry.lock b/poetry.lock index 12d984f22..1bc396c9d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "astroid" @@ -6,7 +6,6 @@ version = "3.2.4" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" -groups = ["dev"] files = [ {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, @@ -21,7 +20,6 @@ version = "22.12.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.7" -groups = ["dev"] files = [ {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, @@ -57,7 +55,6 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" -groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -69,7 +66,6 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" -groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -171,7 +167,6 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" -groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -186,193 +181,17 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["dev"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "coverage" -version = "7.6.1" -description = "Code coverage measurement for Python" -optional = false -python-versions = ">=3.8" -groups = ["dev"] -markers = "python_version < \"3.10\"" -files = [ - {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, - {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, - {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, - {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, - {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, - {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, - {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, - {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, - {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, - {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, - {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, - {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, - {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, - {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, - {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, - {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, - {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, - {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, - {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, - {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, - {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, - {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, - {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, - {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, - {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, - {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, - {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, - {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, - {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, - {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, - {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, - {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, - {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, - {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, - {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, - {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, - {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, - {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, - {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, - {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, - {file = 
"coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, - {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, - {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, - {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, - {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, - {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, - {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, - {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, - {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, - {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, - {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, - {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, - {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, - {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, - {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, - {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, - {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, - {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, - {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, - {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, - {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, - {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, - {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, - {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, - {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, - {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, - {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, - {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, - {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, - {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, - {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, - {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, -] - -[package.dependencies] -tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} - -[package.extras] -toml = ["tomli ; python_full_version <= \"3.11.0a6\""] - -[[package]] -name = "coverage" -version = "7.9.1" -description = "Code coverage measurement for Python" -optional = false -python-versions = ">=3.9" -groups = ["dev"] -markers = "python_version >= \"3.10\"" -files = [ - {file = "coverage-7.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cc94d7c5e8423920787c33d811c0be67b7be83c705f001f7180c7b186dcf10ca"}, - {file = "coverage-7.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16aa0830d0c08a2c40c264cef801db8bc4fc0e1892782e45bcacbd5889270509"}, - {file = "coverage-7.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf95981b126f23db63e9dbe4cf65bd71f9a6305696fa5e2262693bc4e2183f5b"}, - {file = "coverage-7.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f05031cf21699785cd47cb7485f67df619e7bcdae38e0fde40d23d3d0210d3c3"}, - {file = "coverage-7.9.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4fbcab8764dc072cb651a4bcda4d11fb5658a1d8d68842a862a6610bd8cfa3"}, - {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0f16649a7330ec307942ed27d06ee7e7a38417144620bb3d6e9a18ded8a2d3e5"}, - {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cea0a27a89e6432705fffc178064503508e3c0184b4f061700e771a09de58187"}, - {file = "coverage-7.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e980b53a959fa53b6f05343afbd1e6f44a23ed6c23c4b4c56c6662bbb40c82ce"}, - {file = "coverage-7.9.1-cp310-cp310-win32.whl", hash = "sha256:70760b4c5560be6ca70d11f8988ee6542b003f982b32f83d5ac0b72476607b70"}, - {file = "coverage-7.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a66e8f628b71f78c0e0342003d53b53101ba4e00ea8dabb799d9dba0abbbcebe"}, - {file = "coverage-7.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:95c765060e65c692da2d2f51a9499c5e9f5cf5453aeaf1420e3fc847cc060582"}, - {file = "coverage-7.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba383dc6afd5ec5b7a0d0c23d38895db0e15bcba7fb0fa8901f245267ac30d86"}, - {file = "coverage-7.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37ae0383f13cbdcf1e5e7014489b0d71cc0106458878ccde52e8a12ced4298ed"}, - {file = "coverage-7.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69aa417a030bf11ec46149636314c24c8d60fadb12fc0ee8f10fda0d918c879d"}, - {file = "coverage-7.9.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a4be2a28656afe279b34d4f91c3e26eccf2f85500d4a4ff0b1f8b54bf807338"}, - {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:382e7ddd5289f140259b610e5f5c58f713d025cb2f66d0eb17e68d0a94278875"}, - {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e5532482344186c543c37bfad0ee6069e8ae4fc38d073b8bc836fc8f03c9e250"}, - {file = "coverage-7.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a39d18b3f50cc121d0ce3838d32d58bd1d15dab89c910358ebefc3665712256c"}, - {file = "coverage-7.9.1-cp311-cp311-win32.whl", hash = "sha256:dd24bd8d77c98557880def750782df77ab2b6885a18483dc8588792247174b32"}, - {file = "coverage-7.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b55ad10a35a21b8015eabddc9ba31eb590f54adc9cd39bcf09ff5349fd52125"}, - {file = "coverage-7.9.1-cp311-cp311-win_arm64.whl", hash = "sha256:6ad935f0016be24c0e97fc8c40c465f9c4b85cbbe6eac48934c0dc4d2568321e"}, - {file = "coverage-7.9.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8de12b4b87c20de895f10567639c0797b621b22897b0af3ce4b4e204a743626"}, - {file = "coverage-7.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5add197315a054e92cee1b5f686a2bcba60c4c3e66ee3de77ace6c867bdee7cb"}, - {file = "coverage-7.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600a1d4106fe66f41e5d0136dfbc68fe7200a5cbe85610ddf094f8f22e1b0300"}, - {file = "coverage-7.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a876e4c3e5a2a1715a6608906aa5a2e0475b9c0f68343c2ada98110512ab1d8"}, - {file = "coverage-7.9.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81f34346dd63010453922c8e628a52ea2d2ccd73cb2487f7700ac531b247c8a5"}, - {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:888f8eee13f2377ce86d44f338968eedec3291876b0b8a7289247ba52cb984cd"}, - {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9969ef1e69b8c8e1e70d591f91bbc37fc9a3621e447525d1602801a24ceda898"}, - {file = "coverage-7.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:60c458224331ee3f1a5b472773e4a085cc27a86a0b48205409d364272d67140d"}, - {file = "coverage-7.9.1-cp312-cp312-win32.whl", hash = "sha256:5f646a99a8c2b3ff4c6a6e081f78fad0dde275cd59f8f49dc4eab2e394332e74"}, - {file = "coverage-7.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:30f445f85c353090b83e552dcbbdad3ec84c7967e108c3ae54556ca69955563e"}, - {file = "coverage-7.9.1-cp312-cp312-win_arm64.whl", hash = "sha256:af41da5dca398d3474129c58cb2b106a5d93bbb196be0d307ac82311ca234342"}, - {file = "coverage-7.9.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:31324f18d5969feef7344a932c32428a2d1a3e50b15a6404e97cba1cc9b2c631"}, - {file = 
"coverage-7.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0c804506d624e8a20fb3108764c52e0eef664e29d21692afa375e0dd98dc384f"}, - {file = "coverage-7.9.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef64c27bc40189f36fcc50c3fb8f16ccda73b6a0b80d9bd6e6ce4cffcd810bbd"}, - {file = "coverage-7.9.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4fe2348cc6ec372e25adec0219ee2334a68d2f5222e0cba9c0d613394e12d86"}, - {file = "coverage-7.9.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34ed2186fe52fcc24d4561041979a0dec69adae7bce2ae8d1c49eace13e55c43"}, - {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:25308bd3d00d5eedd5ae7d4357161f4df743e3c0240fa773ee1b0f75e6c7c0f1"}, - {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:73e9439310f65d55a5a1e0564b48e34f5369bee943d72c88378f2d576f5a5751"}, - {file = "coverage-7.9.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ab6be0859141b53aa89412a82454b482c81cf750de4f29223d52268a86de67"}, - {file = "coverage-7.9.1-cp313-cp313-win32.whl", hash = "sha256:64bdd969456e2d02a8b08aa047a92d269c7ac1f47e0c977675d550c9a0863643"}, - {file = "coverage-7.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:be9e3f68ca9edb897c2184ad0eee815c635565dbe7a0e7e814dc1f7cbab92c0a"}, - {file = "coverage-7.9.1-cp313-cp313-win_arm64.whl", hash = "sha256:1c503289ffef1d5105d91bbb4d62cbe4b14bec4d13ca225f9c73cde9bb46207d"}, - {file = "coverage-7.9.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0b3496922cb5f4215bf5caaef4cf12364a26b0be82e9ed6d050f3352cf2d7ef0"}, - {file = "coverage-7.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9565c3ab1c93310569ec0d86b017f128f027cab0b622b7af288696d7ed43a16d"}, - {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2241ad5dbf79ae1d9c08fe52b36d03ca122fb9ac6bca0f34439e99f8327ac89f"}, - {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bb5838701ca68b10ebc0937dbd0eb81974bac54447c55cd58dea5bca8451029"}, - {file = "coverage-7.9.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30a25f814591a8c0c5372c11ac8967f669b97444c47fd794926e175c4047ece"}, - {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2d04b16a6062516df97969f1ae7efd0de9c31eb6ebdceaa0d213b21c0ca1a683"}, - {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7931b9e249edefb07cd6ae10c702788546341d5fe44db5b6108a25da4dca513f"}, - {file = "coverage-7.9.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:52e92b01041151bf607ee858e5a56c62d4b70f4dac85b8c8cb7fb8a351ab2c10"}, - {file = "coverage-7.9.1-cp313-cp313t-win32.whl", hash = "sha256:684e2110ed84fd1ca5f40e89aa44adf1729dc85444004111aa01866507adf363"}, - {file = "coverage-7.9.1-cp313-cp313t-win_amd64.whl", hash = "sha256:437c576979e4db840539674e68c84b3cda82bc824dd138d56bead1435f1cb5d7"}, - {file = "coverage-7.9.1-cp313-cp313t-win_arm64.whl", hash = "sha256:18a0912944d70aaf5f399e350445738a1a20b50fbea788f640751c2ed9208b6c"}, - {file = "coverage-7.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f424507f57878e424d9a95dc4ead3fbdd72fd201e404e861e465f28ea469951"}, - {file = "coverage-7.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:535fde4001b2783ac80865d90e7cc7798b6b126f4cd8a8c54acfe76804e54e58"}, - {file = "coverage-7.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02532fd3290bb8fa6bec876520842428e2a6ed6c27014eca81b031c2d30e3f71"}, - {file = "coverage-7.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56f5eb308b17bca3bbff810f55ee26d51926d9f89ba92707ee41d3c061257e55"}, - {file = "coverage-7.9.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfa447506c1a52271f1b0de3f42ea0fa14676052549095e378d5bff1c505ff7b"}, - {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9ca8e220006966b4a7b68e8984a6aee645a0384b0769e829ba60281fe61ec4f7"}, - {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:49f1d0788ba5b7ba65933f3a18864117c6506619f5ca80326b478f72acf3f385"}, - {file = "coverage-7.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:68cd53aec6f45b8e4724c0950ce86eacb775c6be01ce6e3669fe4f3a21e768ed"}, - {file = "coverage-7.9.1-cp39-cp39-win32.whl", hash = "sha256:95335095b6c7b1cc14c3f3f17d5452ce677e8490d101698562b2ffcacc304c8d"}, - {file = "coverage-7.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:e1b5191d1648acc439b24721caab2fd0c86679d8549ed2c84d5a7ec1bedcc244"}, - {file = "coverage-7.9.1-pp39.pp310.pp311-none-any.whl", hash = "sha256:db0f04118d1db74db6c9e1cb1898532c7dcc220f1d2718f058601f7c3f499514"}, - {file = "coverage-7.9.1-py3-none-any.whl", hash = "sha256:66b974b145aa189516b6bf2d8423e888b742517d37872f6ee4c5be0073bd9a3c"}, - {file = "coverage-7.9.1.tar.gz", hash = "sha256:6cf43c78c4282708a28e466316935ec7489a9c487518a77fa68f716c67909cec"}, -] - -[package.dependencies] -tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} - -[package.extras] -toml = ["tomli ; python_full_version <= \"3.11.0a6\""] - [[package]] name = "dill" version = "0.3.9" description = "serialize all of Python" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, @@ -388,7 +207,6 @@ version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, @@ -400,8 +218,6 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["dev"] -markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -416,7 +232,6 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = 
"idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -431,7 +246,6 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -443,7 +257,6 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" -groups = ["dev"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -458,7 +271,6 @@ version = "4.3.3" description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -509,7 +321,6 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" -groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -521,7 +332,6 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -581,7 +391,6 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.5" -groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -593,8 +402,6 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] -markers = "python_version < \"3.10\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -632,8 +439,6 @@ version = "2.2.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" -groups = ["main", "dev"] -markers = "python_version >= \"3.10\"" files = [ {file = "numpy-2.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8146f3550d627252269ac42ae660281d673eb6f8b32f113538e0cc2a9aed42b9"}, {file = "numpy-2.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e642d86b8f956098b564a45e6f6ce68a22c2c97a04f5acd3f221f57b8cb850ae"}, @@ -698,7 +503,6 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" -groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -715,7 +519,6 @@ version = "3.1.5" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, @@ -730,7 +533,6 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -742,8 +544,6 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.10\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -773,7 +573,11 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} +numpy = [ + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, +] python-dateutil = ">=2.8.2" pytz = 
">=2020.1" tzdata = ">=2022.1" @@ -807,8 +611,6 @@ version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" -groups = ["main"] -markers = "python_version >= \"3.10\"" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, @@ -855,11 +657,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, -] +numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.7" @@ -895,7 +693,6 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -907,7 +704,6 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -924,7 +720,6 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -940,8 +735,6 @@ version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, @@ -993,8 +786,6 @@ version = "19.0.1" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" -groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -1049,7 +840,6 @@ version = "3.2.7" description = "python code static checker" optional = false python-versions = ">=3.8.0" -groups = ["dev"] files = [ {file = "pylint-3.2.7-py3-none-any.whl", hash = 
"sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, @@ -1061,7 +851,7 @@ colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version == \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -1080,7 +870,6 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" -groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -1097,32 +886,12 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] -[[package]] -name = "pytest-cov" -version = "4.1.0" -description = "Pytest plugin for measuring coverage." -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, - {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, -] - -[package.dependencies] -coverage = {version = ">=5.2.1", extras = ["toml"]} -pytest = ">=4.6" - -[package.extras] -testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] - [[package]] name = "pytest-dotenv" version = "0.5.2" description = "A py.test plugin that parses environment files before running tests" optional = false python-versions = "*" -groups = ["dev"] files = [ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, @@ -1138,7 +907,6 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -1153,7 +921,6 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -1168,7 +935,6 @@ version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" -groups = ["main"] files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash 
= "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, @@ -1180,7 +946,6 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -1202,7 +967,6 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1214,7 +978,6 @@ version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" -groups = ["main"] files = [ {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] @@ -1233,8 +996,6 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" -groups = ["dev"] -markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1276,7 +1037,6 @@ version = "0.13.2" description = "Style preserving TOML library" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, @@ -1288,7 +1048,6 @@ version = "4.13.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5"}, {file = "typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b"}, @@ -1300,7 +1059,6 @@ version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" -groups = ["main"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, @@ -1312,14 +1070,13 @@ version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] -brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1328,6 +1085,6 @@ zstd = ["zstandard (>=0.18.0)"] pyarrow = ["pyarrow", "pyarrow"] [metadata] -lock-version = "2.1" +lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "7565c2cfcd646d789c9da8fd7b9f33cc1d592c434d3fdf1cf6063cbb0362dc10" +content-hash = "0bd6a6a019693a69a3da5ae312cea625ea73dfc5832b1e4051c7c7d1e76553d8" diff --git a/pyproject.toml b/pyproject.toml index 3def9abdf..7b95a5097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,6 @@ numpy = [ "Homepage" = "https://github.com/databricks/databricks-sql-python" "Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues" -[tool.poetry.group.dev.dependencies] -pytest-cov = "4.1.0" - [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From a0705bc455dd2eb6b29e666508df5c426b6c5d2a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:23:59 +0000 Subject: [PATCH 119/204] add fetchmany_arrow and fetchall_arrow Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 77 +---------------------- src/databricks/sql/result_set.py | 41 ++++++++++++ 2 files changed, 42 insertions(+), 76 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 1e4eb3253..79ab30c9e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -302,74 +302,6 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - def _get_schema_bytes(self, sea_response) -> Optional[bytes]: - """ - Extract schema bytes from the SEA response. - - For ARROW format, we need to get the schema bytes from the first chunk. - If the first chunk is not available, we need to get it from the server. 
- - Args: - sea_response: The response from the SEA API - - Returns: - bytes: The schema bytes or None if not available - """ - import requests - import lz4.frame - - # Check if we have the first chunk in the response - result_data = sea_response.get("result", {}) - external_links = result_data.get("external_links", []) - - if not external_links: - return None - - # Find the first chunk (chunk_index = 0) - first_chunk = None - for link in external_links: - if link.get("chunk_index") == 0: - first_chunk = link - break - - if not first_chunk: - # Try to fetch the first chunk from the server - statement_id = sea_response.get("statement_id") - if not statement_id: - return None - - chunks_response = self.get_chunk_links(statement_id, 0) - if not chunks_response.external_links: - return None - - first_chunk = chunks_response.external_links[0].__dict__ - - # Download the first chunk to get the schema bytes - external_link = first_chunk.get("external_link") - http_headers = first_chunk.get("http_headers", {}) - - if not external_link: - return None - - # Use requests to download the first chunk - http_response = requests.get( - external_link, - headers=http_headers, - verify=self.ssl_options.tls_verify, - ) - - if http_response.status_code != 200: - raise Error(f"Failed to download schema bytes: {http_response.text}") - - # Extract schema bytes from the Arrow file - # The schema is at the beginning of the file - data = http_response.content - if sea_response.get("manifest", {}).get("result_compression") == "LZ4_FRAME": - data = lz4.frame.decompress(data) - - # Return the schema bytes - return data - def _results_message_to_execute_response(self, sea_response, command_id): """ Convert a SEA response to an ExecuteResponse and extract result data. @@ -412,13 +344,6 @@ def _results_message_to_execute_response(self, sea_response, command_id): ) description = columns if columns else None - # Extract schema bytes for Arrow format - schema_bytes = None - format = manifest_data.get("format") - if format == "ARROW_STREAM": - # For ARROW format, we need to get the schema bytes - schema_bytes = self._get_schema_bytes(sea_response) - # Check for compression lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME" @@ -473,7 +398,7 @@ def _results_message_to_execute_response(self, sea_response, command_id): has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=False, - arrow_schema_bytes=schema_bytes, + arrow_schema_bytes=None, result_format=manifest_data.get("format"), ) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 12ba1ee20..06462e92f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -154,6 +154,16 @@ def fetchall(self) -> List[Row]: """Fetch all remaining rows of a query result.""" pass + @abstractmethod + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """Fetch the next set of rows as an Arrow table.""" + pass + + @abstractmethod + def fetchall_arrow(self) -> "pyarrow.Table": + """Fetch all remaining rows as an Arrow table.""" + pass + def close(self) -> None: """ Close the result set. @@ -537,6 +547,37 @@ def fetchall_json(self): return results + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows as an Arrow table. 
+ + Args: + size: Number of rows to fetch + + Returns: + PyArrow Table containing the fetched rows + + Raises: + ImportError: If PyArrow is not installed + ValueError: If size is negative + """ + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + results = self.results.next_n_rows(size) + self._next_row_index += results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """ + Fetch all remaining rows as an Arrow table. + """ + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + return results + def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, From f7c11b9c62452817aa52133ae826db97f914f98a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:28:07 +0000 Subject: [PATCH 120/204] remove accidental changes in sea backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index d82393bf0..b3171533f 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -526,7 +526,7 @@ def test_get_execution_result( print(result) # Verify basic properties of the result - assert result.command_id.to_sea_statement_id() == "test-statement-123" + assert result.statement_id == "test-statement-123" assert result.status == CommandState.SUCCEEDED # Verify the HTTP request From 62298486dd4d4d20ee5503a32ab73ca70a609294 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:37:16 +0000 Subject: [PATCH 121/204] remove irrelevant changes Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/models/requests.py | 4 +- .../sql/backend/sea/models/responses.py | 2 +- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 130 ++++++++++-------- src/databricks/sql/result_set.py | 84 ++++++----- src/databricks/sql/utils.py | 6 +- 6 files changed, 131 insertions(+), 97 deletions(-) diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 8524275d4..4c5071dba 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CreateSessionRequest: - """Request to create a new session.""" + """Representation of a request to create a new session.""" warehouse_id: str session_confs: Optional[Dict[str, str]] = None @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DeleteSessionRequest: - """Request to delete a session.""" + """Representation of a request to delete a session.""" warehouse_id: str session_id: str diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 43283a8b0..dae37b1ae 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": @dataclass class CreateSessionResponse: - """Response from creating a new session.""" + """Representation of the response from creating a new session.""" session_id: str diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index f0b931ee4..fe292919c 100644 --- 
a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,7 +1,7 @@ import json import logging import requests -from typing import Callable, Dict, Any, Optional, Union, List, Tuple +from typing import Callable, Dict, Any, Optional, List, Tuple from urllib.parse import urljoin from databricks.sql.auth.authenticators import AuthProvider diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 48e9a115f..e824de1c2 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -3,24 +3,21 @@ import logging import math import time -import uuid import threading from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor -from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( CommandState, SessionId, CommandId, - BackendType, - guid_to_hex_id, ExecuteResponse, ) from databricks.sql.backend.utils import guid_to_hex_id + try: import pyarrow except ImportError: @@ -760,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state): ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation - has_more_rows = ( + + is_direct_results = ( (not direct_results) or (not direct_results.resultSet) or direct_results.resultSet.hasMoreRows ) + description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -780,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = None lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = t_result_set_metadata_resp.isStagingOperation - if direct_results and direct_results.resultSet: - assert direct_results.resultSet.results.startRowOffset == 0 - assert direct_results.resultSetMetadata - - arrow_queue_opt = ResultSetQueueFactory.build_queue( - row_set_type=t_result_set_metadata_resp.resultFormat, - t_row_set=direct_results.resultSet.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) - else: - arrow_queue_opt = None - command_id = CommandId.from_thrift_handle(resp.operationHandle) status = CommandState.from_thrift_state(operation_state) if status is None: raise ValueError(f"Unknown command state: {operation_state}") - return ( - ExecuteResponse( - command_id=command_id, - status=status, - description=description, - has_more_rows=has_more_rows, - results_queue=arrow_queue_opt, - has_been_closed_server_side=has_been_closed_server_side, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - ), - schema_bytes, + execute_response = ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=t_result_set_metadata_resp.isStagingOperation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) + return execute_response, is_direct_results + def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": @@ -841,9 +822,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - lz4_compressed = t_result_set_metadata_resp.lz4Compressed - is_staging_operation = 
t_result_set_metadata_resp.isStagingOperation - has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) @@ -858,25 +836,21 @@ def get_execution_result( else: schema_bytes = None - queue = ResultSetQueueFactory.build_queue( - row_set_type=resp.resultSetMetadata.resultFormat, - t_row_set=resp.results, - arrow_schema_bytes=schema_bytes, - max_download_threads=self.max_download_threads, - lz4_compressed=lz4_compressed, - description=description, - ssl_options=self._ssl_options, - ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation + is_direct_results = resp.hasMoreRows + + status = self.get_query_state(command_id) execute_response = ExecuteResponse( command_id=command_id, status=status, description=description, - has_more_rows=has_more_rows, - results_queue=queue, has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, + arrow_schema_bytes=schema_bytes, + result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( @@ -886,7 +860,10 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=schema_bytes, + t_row_set=resp.results, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -999,10 +976,14 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1010,7 +991,10 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_catalogs( @@ -1032,10 +1016,14 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1043,7 +1031,10 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_schemas( @@ -1069,10 +1060,14 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + 
t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1080,7 +1075,10 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_tables( @@ -1110,10 +1108,14 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1121,7 +1123,10 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def get_columns( @@ -1151,10 +1156,14 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, arrow_schema_bytes = self._handle_execute_response( + execute_response, is_direct_results = self._handle_execute_response( resp, cursor ) + t_row_set = None + if resp.directResults and resp.directResults.resultSet: + t_row_set = resp.directResults.resultSet.results + return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, @@ -1162,7 +1171,10 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, - arrow_schema_bytes=arrow_schema_bytes, + t_row_set=t_row_set, + max_download_threads=self.max_download_threads, + ssl_options=self._ssl_options, + is_direct_results=is_direct_results, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index b2ecd00f0..38b8a3c2f 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,7 +41,7 @@ def __init__( command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, - has_more_rows: bool = False, + is_direct_results: bool = False, results_queue=None, description=None, is_staging_operation: bool = False, @@ -51,18 +51,18 @@ def __init__( """ A ResultSet manages the results of a single command. 
- Args: - connection: The parent connection - backend: The backend client - arraysize: The max number of rows to fetch at a time (PEP-249) - buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - command_id: The command ID - status: The command status - has_been_closed_server_side: Whether the command has been closed on the server - has_more_rows: Whether the command has more rows - results_queue: The results queue - description: column description of the results - is_staging_operation: Whether the command is a staging operation + Parameters: + :param connection: The parent connection + :param backend: The backend client + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch + :param command_id: The command ID + :param status: The command status + :param has_been_closed_server_side: Whether the command has been closed on the server + :param is_direct_results: Whether the command has more rows + :param results_queue: The results queue + :param description: column description of the results + :param is_staging_operation: Whether the command is a staging operation """ self.connection = connection @@ -74,7 +74,7 @@ def __init__( self.command_id = command_id self.status = status self.has_been_closed_server_side = has_been_closed_server_side - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results self.results = results_queue self._is_staging_operation = is_staging_operation self.lz4_compressed = lz4_compressed @@ -161,25 +161,47 @@ def __init__( buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, - arrow_schema_bytes: Optional[bytes] = None, + t_row_set=None, + max_download_threads: int = 10, + ssl_options=None, + is_direct_results: bool = True, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
- Args: - connection: The parent connection - execute_response: Response from the execute command - thrift_client: The ThriftDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - use_cloud_fetch: Whether to use cloud fetch for retrieving results - arrow_schema_bytes: Arrow schema bytes for the result set + Parameters: + :param connection: The parent connection + :param execute_response: Response from the execute command + :param thrift_client: The ThriftDatabricksClient instance for direct access + :param buffer_size_bytes: Buffer size for fetching results + :param arraysize: Default number of rows to fetch + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results + :param t_row_set: The TRowSet containing result data (if available) + :param max_download_threads: Maximum number of download threads for cloud fetch + :param ssl_options: SSL options for cloud fetch + :param is_direct_results: Whether there are more rows to fetch """ # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results + # Build the results queue if t_row_set is provided + results_queue = None + if t_row_set and execute_response.result_format is not None: + from databricks.sql.utils import ResultSetQueueFactory + + # Create the results queue using the provided format + results_queue = ResultSetQueueFactory.build_queue( + row_set_type=execute_response.result_format, + t_row_set=t_row_set, + arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + max_download_threads=max_download_threads, + lz4_compressed=execute_response.lz4_compressed, + description=execute_response.description, + ssl_options=ssl_options, + ) + # Call parent constructor with common attributes super().__init__( connection=connection, @@ -189,8 +211,8 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - has_more_rows=execute_response.has_more_rows, - results_queue=execute_response.results_queue, + is_direct_results=is_direct_results, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, @@ -202,7 +224,7 @@ def __init__( self._fill_results_buffer() def _fill_results_buffer(self): - results, has_more_rows = self.backend.fetch_results( + results, is_direct_results = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, @@ -213,7 +235,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self.is_direct_results = is_direct_results def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -297,7 +319,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -322,7 +344,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self.is_direct_results ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -337,7 +359,7 
@@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -363,7 +385,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self.is_direct_results: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index edb13ef6d..d7b1b74b4 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import re import lz4.frame @@ -57,7 +57,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -206,7 +206,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[List[Any]]] = None, + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
From fd5235606bbf307432e375a57e760319fc78709e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:39:42 +0000 Subject: [PATCH 122/204] remove un-necessary test changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/types.py | 8 +++--- tests/unit/test_client.py | 11 +++++--- tests/unit/test_fetches.py | 39 ++++++++++++++++------------- tests/unit/test_fetches_bench.py | 2 +- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 3d24a09a1..93bd7d525 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,11 +423,9 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[ - List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] - ] = None - has_more_rows: bool = False - results_queue: Optional[Any] = None + description: Optional[List[Tuple]] = None has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False + arrow_schema_bytes: Optional[bytes] = None + result_format: Optional[Any] = None diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 090ec255e..2054d01d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -48,7 +48,7 @@ def new(cls): is_staging_operation=False, command_id=None, has_been_closed_server_side=True, - has_more_rows=True, + is_direct_results=True, lz4_compressed=True, arrow_schema_bytes=b"schema", ) @@ -104,6 +104,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.staging_allowed_local_path = None + mock_backend.fetch_results.return_value = (Mock(), False) # Configure the decorator's mock to return our specific mock_backend mock_thrift_client_class.return_value = mock_backend @@ -184,6 +185,7 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( connection=mock_connection, @@ -210,6 +212,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) + mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -254,7 +257,10 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = ThriftResultSet(Mock(), Mock(), Mock()) + mock_backend = Mock() + mock_backend.fetch_results.return_value = (Mock(), False) + + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) @@ -472,7 +478,6 @@ def make_fake_row_slice(n_rows): mock_aq = Mock() mock_aq.next_n_rows.side_effect = make_fake_row_slice mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.execute("foo") diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7249a59e6..a649941e1 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py 
@@ -40,25 +40,30 @@ def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more schema, arrow_table = FetchTests.make_arrow_table(initial_results) arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + + # Create a mock backend that will return the queue when _fill_results_buffer is called + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + + num_cols = len(initial_results[0]) if initial_results else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=True, - has_more_rows=False, - description=Mock(), - lz4_compressed=Mock(), - results_queue=arrow_queue, + description=description, + lz4_compressed=True, is_staging_operation=False, ), - thrift_client=None, + thrift_client=mock_thrift_backend, + t_row_set=None, ) - num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ] return rs @staticmethod @@ -85,19 +90,19 @@ def fetch_results( mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 + description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] + rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( command_id=None, status=None, has_been_closed_server_side=False, - has_more_rows=True, - description=[ - (f"col{col_id}", "integer", None, None, None, None, None) - for col_id in range(num_cols) - ], - lz4_compressed=Mock(), - results_queue=None, + description=description, + lz4_compressed=True, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 7e025cf82..e4a9e5cdd 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -36,7 +36,7 @@ def make_dummy_result_set_from_initial_results(arrow_table): execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, - has_more_rows=False, + is_direct_results=False, description=Mock(), command_id=None, arrow_queue=arrow_queue, From 64e58b05415591a22feb4ab8ed52440c63be0d49 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:41:51 +0000 Subject: [PATCH 123/204] remove un-necessary changes in thrift backend tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_thrift_backend.py | 106 ++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 8274190fe..57b5e9b58 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -623,7 +623,10 @@ def test_handle_execute_response_sets_compression_in_direct_results( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=op_status, + operationStatus=ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -832,9 +835,10 @@ def 
test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value @@ -878,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) - + ( + execute_response, + _, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -947,8 +951,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) @@ -973,8 +983,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req + tcli_service_instance.GetOperationStatus.return_value = ( + ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + ) thrift_backend = self._make_fake_thrift_backend() - thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -988,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 @@ -1003,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( resultSetMetadata=self.metadata_resp, resultSet=ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, ), closeOperation=Mock(), @@ -1019,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( - execute_resp, Mock() - ) + ( + execute_response, + has_more_rows_result, + ) = thrift_backend._handle_execute_response(execute_resp, Mock()) - self.assertEqual(has_more_rows, execute_response.has_more_rows) + self.assertEqual(is_direct_results, has_more_rows_result) @patch( 
"databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() @@ -1032,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): - for has_more_rows, resp_type in itertools.product( + for is_direct_results, resp_type in itertools.product( [True, False], self.execute_response_types ): - with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): + with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type): tcli_service_instance = tcli_service_class.return_value results_mock = MagicMock() results_mock.startRowOffset = 0 @@ -1048,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( fetch_results_resp = ttypes.TFetchResultsResp( status=self.okay_status, - hasMoreRows=has_more_rows, + hasMoreRows=is_direct_results, results=results_mock, resultSetMetadata=ttypes.TGetResultSetMetadataResp( resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET @@ -1081,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( description=Mock(), ) - self.assertEqual(has_more_rows, has_more_rows_resp) + self.assertEqual(is_direct_results, has_more_rows_resp) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): @@ -1136,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1151,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( "foo", Mock(), 100, 200, Mock(), cursor_mock ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.ExecuteStatement.call_args[0][0] @@ -1170,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1185,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, 
cursor_mock) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetCatalogs.call_args[0][0] @@ -1201,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1216,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1228,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( schema_name="schema_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetSchemas.call_args[0][0] @@ -1241,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1256,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1270,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response( table_types=["type1", "type2"], ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetTables.call_args[0][0] @@ -1285,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( - self, tcli_service_class + self, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value response = Mock() @@ -1300,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) + thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1314,7 +1341,7 @@ def 
test_get_columns_calls_client_and_handle_execute_response( column_name="column_pattern", ) # Verify the result is a ResultSet - self.assertIsInstance(result, ResultSet) + self.assertEqual(result, mock_result_set.return_value) # Check call to client req = tcli_service_instance.GetColumns.call_args[0][0] @@ -2203,14 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) + @patch("databricks.sql.backend.thrift_backend.ThriftResultSet") @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" ) def test_execute_command_sets_complex_type_fields_correctly( - self, mock_handle_execute_response, tcli_service_class + self, mock_handle_execute_response, tcli_service_class, mock_result_set ): tcli_service_instance = tcli_service_class.return_value + # Set up the mock to return a tuple with two values + mock_execute_response = Mock() + mock_arrow_schema = Mock() + mock_handle_execute_response.return_value = ( + mock_execute_response, + mock_arrow_schema, + ) + # Iterate through each possible combination of native types (True, False and unset) for complex, timestamp, decimals in itertools.product( [True, False, None], [True, False, None], [True, False, None] From 2903473a7f6ca72fa8400304cb002992a2471e6e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:43:37 +0000 Subject: [PATCH 124/204] remove unimplemented methods test Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 52 ---------------------------------- 1 file changed, 52 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 1d16763be..1434ed831 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -599,55 +599,3 @@ def test_utility_methods(self, sea_client): manifest_obj.schema = {} description = sea_client._extract_description_from_manifest(manifest_obj) assert description is None - - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" - ) - - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], - ) - - # Test get_columns - with pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", - ) From 021ff4ce733a568879b7f3f184bd5629ff22406c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: 
Tue, 17 Jun 2025 03:58:41 +0000 Subject: [PATCH 125/204] remove unimplemented method tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/thrift_backend.py | 2 - src/databricks/sql/result_set.py | 1 - src/databricks/sql/utils.py | 1 - tests/unit/test_sea_result_set.py | 77 -------------------- 4 files changed, 81 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index cab5e0052..02d335aa4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -1232,8 +1232,6 @@ def fetch_results( ) ) - from databricks.sql.utils import ThriftResultSetQueueFactory - queue = ThriftResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 03f9895ce..49394b12a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -480,7 +480,6 @@ def __init__( str(execute_response.command_id.to_sea_statement_id()), description=execute_response.description, max_download_threads=sea_client.max_download_threads, - ssl_options=sea_client.ssl_options, sea_client=sea_client, lz4_compressed=execute_response.lz4_compressed, ) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d3f2d9ee3..ac855e30d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -119,7 +119,6 @@ def build_queue( description: Optional[List[Tuple[Any, ...]]] = None, schema_bytes: Optional[bytes] = None, max_download_threads: Optional[int] = None, - ssl_options: Optional[SSLOptions] = None, sea_client: Optional["SeaDatabricksClient"] = None, lz4_compressed: bool = False, ) -> ResultSetQueue: diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index c596dbc14..f0049e3aa 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -122,80 +122,3 @@ def test_close_when_connection_closed( mock_sea_client.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that unimplemented methods raise NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - - with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", - ): - result_set.fetchmany_arrow(10) - - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not 
implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test using the result set in a for loop - for row in result_set: - pass - - def test_fill_results_buffer_not_implemented( - self, mock_connection, mock_sea_client, execute_response - ): - """Test that _fill_results_buffer raises NotImplementedError.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - with pytest.raises( - NotImplementedError, - match="_fill_results_buffer is not implemented for SEA backend", - ): - result_set._fill_results_buffer() From adecd5354899514e9355f40e2776991432fcea7b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 04:01:59 +0000 Subject: [PATCH 126/204] modify example scripts to include fetch calls Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 168 ++++++++----- examples/experimental/tests/__init__.py | 0 .../tests/test_sea_async_query.py | 236 ++++++++++++++++++ .../experimental/tests/test_sea_metadata.py | 98 ++++++++ .../experimental/tests/test_sea_session.py | 71 ++++++ .../experimental/tests/test_sea_sync_query.py | 176 +++++++++++++ 6 files changed, 693 insertions(+), 56 deletions(-) create mode 100644 examples/experimental/tests/__init__.py create mode 100644 examples/experimental/tests/test_sea_async_query.py create mode 100644 examples/experimental/tests/test_sea_metadata.py create mode 100644 examples/experimental/tests/test_sea_session.py create mode 100644 examples/experimental/tests/test_sea_sync_query.py diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1ab..6d72833d5 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -1,66 +1,122 @@ +""" +Main script to run all SEA connector tests. + +This script runs all the individual test modules and displays +a summary of test results with visual indicators. +""" import os import sys import logging -from databricks.sql.client import Connection +import subprocess +from typing import List, Tuple logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -def test_sea_session(): - """ - Test opening and closing a SEA session using the connector. - - This function connects to a Databricks SQL endpoint using the SEA backend, - opens a session, and then closes it. 
- - Required environment variables: - - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - - DATABRICKS_TOKEN: Personal access token for authentication - """ - - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - access_token = os.environ.get("DATABRICKS_TOKEN") - catalog = os.environ.get("DATABRICKS_CATALOG") - - if not all([server_hostname, http_path, access_token]): - logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") - sys.exit(1) - - logger.info(f"Connecting to {server_hostname}") - logger.info(f"HTTP Path: {http_path}") - if catalog: - logger.info(f"Using catalog: {catalog}") - - try: - logger.info("Creating connection with SEA backend...") - connection = Connection( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - catalog=catalog, - schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent +TEST_MODULES = [ + "test_sea_session", + "test_sea_sync_query", + "test_sea_async_query", + "test_sea_metadata", + "test_sea_multi_chunk", +] + + +def run_test_module(module_name: str) -> bool: + """Run a test module and return success status.""" + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" + ) + + # Handle the multi-chunk test which is in the main directory + if module_name == "test_sea_multi_chunk": + module_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), f"{module_name}.py" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") - logger.info(f"backend type: {type(connection.session.backend)}") - - # Close the connection - logger.info("Closing the SEA session...") - connection.close() - logger.info("Successfully closed SEA session") - - except Exception as e: - logger.error(f"Error testing SEA session: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - sys.exit(1) - - logger.info("SEA session test completed successfully") + + # Simply run the module as a script - each module handles its own test execution + result = subprocess.run( + [sys.executable, module_path], capture_output=True, text=True + ) + + # Log the output from the test module + if result.stdout: + for line in result.stdout.strip().split("\n"): + logger.info(line) + + if result.stderr: + for line in result.stderr.strip().split("\n"): + logger.error(line) + + return result.returncode == 0 + + +def run_tests() -> List[Tuple[str, bool]]: + """Run all tests and return results.""" + results = [] + + for module_name in TEST_MODULES: + try: + logger.info(f"\n{'=' * 50}") + logger.info(f"Running test: {module_name}") + logger.info(f"{'-' * 50}") + + success = run_test_module(module_name) + results.append((module_name, success)) + + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"Test {module_name}: {status}") + + except Exception as e: + logger.error(f"Error loading or running test {module_name}: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + results.append((module_name, False)) + + return results + + +def print_summary(results: List[Tuple[str, bool]]) -> None: + """Print a summary of test results.""" + logger.info(f"\n{'=' * 50}") + logger.info("TEST SUMMARY") + logger.info(f"{'-' * 50}") + + passed = sum(1 for _, success in results if success) + total 
= len(results) + + for module_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + logger.info(f"{status} - {module_name}") + + logger.info(f"{'-' * 50}") + logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}") + logger.info(f"{'=' * 50}") + if __name__ == "__main__": - test_sea_session() + # Check if required environment variables are set + required_vars = [ + "DATABRICKS_SERVER_HOSTNAME", + "DATABRICKS_HTTP_PATH", + "DATABRICKS_TOKEN", + ] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + + if missing_vars: + logger.error( + f"Missing required environment variables: {', '.join(missing_vars)}" + ) + logger.error("Please set these variables before running the tests.") + sys.exit(1) + + # Run all tests + results = run_tests() + + # Print summary + print_summary(results) + + # Exit with appropriate status code + all_passed = all(success for _, success in results) + sys.exit(0 if all_passed else 1) diff --git a/examples/experimental/tests/__init__.py b/examples/experimental/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py new file mode 100644 index 000000000..3a4de778c --- /dev/null +++ b/examples/experimental/tests/test_sea_async_query.py @@ -0,0 +1,236 @@ +""" +Test for SEA asynchronous query execution functionality. +""" +import os +import sys +import logging +import time +from databricks.sql.client import Connection +from databricks.sql.backend.types import CommandState + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_async_query_with_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 5000 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing asynchronous query with cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute_async(query) + logger.info( + "Asynchronous query submitted successfully with cloud fetch enabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + + logger.info("PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_without_cloud_fetch(): + """ + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for asynchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 100)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing asynchronous query without cloud fetch to generate {requested_row_count} rows" + ) + cursor.execute_async(query) + logger.info( + "Asynchronous query submitted successfully with cloud fetch disabled" + ) + + # Check query state + logger.info("Checking query state...") + while cursor.is_query_pending(): + logger.info("Query is still pending, waiting...") + time.sleep(1) + + logger.info("Query is no longer pending, getting results...") + cursor.get_async_execution_result() + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + + logger.info( + f"Requested {requested_row_count} rows, received {actual_row_count} rows" + ) + + # Verify total row count + if actual_row_count != requested_row_count: + logger.error( + f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" + ) + return False + + logger.info("PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_async_query_exec(): + """ + Run both asynchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch() + logger.info( + f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch() + logger.info( + f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_async_query_exec() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_metadata.py b/examples/experimental/tests/test_sea_metadata.py new file mode 100644 index 000000000..a200d97d3 --- /dev/null +++ b/examples/experimental/tests/test_sea_metadata.py @@ -0,0 +1,98 @@ +""" +Test for SEA metadata functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_metadata(): + """ + Test metadata operations using the SEA backend. 
+ + This function connects to a Databricks SQL endpoint using the SEA backend, + and executes metadata operations like catalogs(), schemas(), tables(), and columns(). + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + if not catalog: + logger.error( + "DATABRICKS_CATALOG environment variable is required for metadata tests." + ) + return False + + try: + # Create connection + logger.info("Creating connection for metadata operations") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Test catalogs + cursor = connection.cursor() + logger.info("Fetching catalogs...") + cursor.catalogs() + logger.info("Successfully fetched catalogs") + + # Test schemas + logger.info(f"Fetching schemas for catalog '{catalog}'...") + cursor.schemas(catalog_name=catalog) + logger.info("Successfully fetched schemas") + + # Test tables + logger.info(f"Fetching tables for catalog '{catalog}', schema 'default'...") + cursor.tables(catalog_name=catalog, schema_name="default") + logger.info("Successfully fetched tables") + + # Test columns for a specific table + # Using a common table that should exist in most environments + logger.info( + f"Fetching columns for catalog '{catalog}', schema 'default', table 'customer'..." + ) + cursor.columns( + catalog_name=catalog, schema_name="default", table_name="customer" + ) + logger.info("Successfully fetched columns") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error during SEA metadata test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_metadata() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_session.py b/examples/experimental/tests/test_sea_session.py new file mode 100644 index 000000000..516c1bbb8 --- /dev/null +++ b/examples/experimental/tests/test_sea_session.py @@ -0,0 +1,71 @@ +""" +Test for SEA session management functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_session(): + """ + Test opening and closing a SEA session using the connector. + + This function connects to a Databricks SQL endpoint using the SEA backend, + opens a session, and then closes it. 
+ + Required environment variables: + - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname + - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint + - DATABRICKS_TOKEN: Personal access token for authentication + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + return False + + try: + logger.info("Creating connection with SEA backend...") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"Backend type: {type(connection.session.backend)}") + + # Close the connection + logger.info("Closing the SEA session...") + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error(f"Error testing SEA session: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + success = test_sea_session() + sys.exit(0 if success else 1) diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py new file mode 100644 index 000000000..c69a84b8a --- /dev/null +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -0,0 +1,176 @@ +""" +Test for SEA synchronous query execution functionality. +""" +import os +import sys +import logging +from databricks.sql.client import Connection + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def test_sea_sync_query_with_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch enabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query with cloud fetch enabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch enabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch enabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a query that generates large rows to force multiple chunks + requested_row_count = 10000 + cursor = connection.cursor() + query = f""" + SELECT + id, + concat('value_', repeat('a', 10000)) as test_value + FROM range(1, {requested_row_count} + 1) AS t(id) + """ + + logger.info( + f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" + ) + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test with cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_without_cloud_fetch(): + """ + Test executing a query synchronously using the SEA backend with cloud fetch disabled. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a query with cloud fetch disabled, and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." 
+ ) + return False + + try: + # Create connection with cloud fetch disabled + logger.info( + "Creating connection for synchronous query execution with cloud fetch disabled" + ) + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, + enable_query_result_lz4_compression=False, + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits + requested_row_count = 100 + cursor = connection.cursor() + logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows") + cursor.execute( + "SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)" + ) + + results = [cursor.fetchone()] + results.extend(cursor.fetchmany(10)) + results.extend(cursor.fetchall()) + logger.info(f"{len(results)} rows retrieved against 100 requested") + + # Close resources + cursor.close() + connection.close() + logger.info("Successfully closed SEA session") + + return True + + except Exception as e: + logger.error( + f"Error during SEA synchronous query execution test without cloud fetch: {str(e)}" + ) + import traceback + + logger.error(traceback.format_exc()) + return False + + +def test_sea_sync_query_exec(): + """ + Run both synchronous query tests and return overall success. + """ + with_cloud_fetch_success = test_sea_sync_query_with_cloud_fetch() + logger.info( + f"Synchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}" + ) + + without_cloud_fetch_success = test_sea_sync_query_without_cloud_fetch() + logger.info( + f"Synchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}" + ) + + return with_cloud_fetch_success and without_cloud_fetch_success + + +if __name__ == "__main__": + success = test_sea_sync_query_exec() + sys.exit(0 if success else 1) From bfc1f013b61b50f17bc86e57fbe805ca93096d23 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 04:43:22 +0000 Subject: [PATCH 127/204] fix sea connector tests Signed-off-by: varun-edachali-dbx --- examples/experimental/sea_connector_test.py | 7 ----- .../tests/test_sea_async_query.py | 26 +++++++++++++------ .../experimental/tests/test_sea_sync_query.py | 8 ++++-- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index 6d72833d5..edd171b05 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -18,7 +18,6 @@ "test_sea_sync_query", "test_sea_async_query", "test_sea_metadata", - "test_sea_multi_chunk", ] @@ -28,12 +27,6 @@ def run_test_module(module_name: str) -> bool: os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py" ) - # Handle the multi-chunk test which is in the main directory - if module_name == "test_sea_multi_chunk": - module_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), f"{module_name}.py" - ) - # Simply run the module as a script - each module handles its own test execution result = subprocess.run( [sys.executable, module_path], capture_output=True, text=True diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 3a4de778c..f805834b4 100644 --- 
a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -77,24 +77,29 @@ def test_sea_async_query_with_cloud_fetch(): logger.info("Query is no longer pending, getting results...") cursor.get_async_execution_result() - + results = [cursor.fetchone()] results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) - logger.info(f"{len(results)} rows retrieved against 100 requested") + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) logger.info( f"Requested {requested_row_count} rows, received {actual_row_count} rows" ) - + # Verify total row count if actual_row_count != requested_row_count: logger.error( f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}" ) return False - - logger.info("PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly") + + logger.info( + "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly" + ) # Close resources cursor.close() @@ -182,12 +187,15 @@ def test_sea_async_query_without_cloud_fetch(): results = [cursor.fetchone()] results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) - logger.info(f"{len(results)} rows retrieved against 100 requested") + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) logger.info( f"Requested {requested_row_count} rows, received {actual_row_count} rows" ) - + # Verify total row count if actual_row_count != requested_row_count: logger.error( @@ -195,7 +203,9 @@ def test_sea_async_query_without_cloud_fetch(): ) return False - logger.info("PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly") + logger.info( + "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly" + ) # Close resources cursor.close() diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index c69a84b8a..bfb86b82b 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -6,7 +6,7 @@ import logging from databricks.sql.client import Connection -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -62,10 +62,14 @@ def test_sea_sync_query_with_cloud_fetch(): logger.info( f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows" ) + cursor.execute(query) results = [cursor.fetchone()] results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) - logger.info(f"{len(results)} rows retrieved against 100 requested") + actual_row_count = len(results) + logger.info( + f"{actual_row_count} rows retrieved against {requested_row_count} requested" + ) # Close resources cursor.close() From 0a2cdfd7a08fcf48db3eb80b475315e56f876921 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 03:43:37 +0000 Subject: [PATCH 128/204] remove unimplemented methods test Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 52 ---------------------------------- 1 file changed, 52 deletions(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 1d16763be..1434ed831 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -599,55 +599,3 @@ def 
test_utility_methods(self, sea_client): manifest_obj.schema = {} description = sea_client._extract_description_from_manifest(manifest_obj) assert description is None - - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" - ) - - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], - ) - - # Test get_columns - with pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", - ) From cd22389fcc12713ec0c24715001b9067f856242b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 17 Jun 2025 05:16:36 +0000 Subject: [PATCH 129/204] remove invalid import Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 373a1b6d1..24a8880af 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -29,7 +29,6 @@ from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.backend.types import ExecuteResponse -from databricks.sql.utils import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite From 5ab9bbe4fff28a60eb35439130a589b83375789b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 04:34:26 +0000 Subject: [PATCH 130/204] better align queries with JDBC impl Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 3b9d92151..49534ea16 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -645,7 +645,7 @@ def get_schemas( if not catalog_name: raise ValueError("Catalog name is required for get_schemas") - operation = f"SHOW SCHEMAS IN `{catalog_name}`" + operation = f"SHOW SCHEMAS IN {catalog_name}" if schema_name: operation += f" LIKE '{schema_name}'" @@ -683,7 +683,7 @@ def get_tables( operation = "SHOW TABLES IN " + ( "ALL CATALOGS" if catalog_name in [None, "*", "%"] - else f"CATALOG `{catalog_name}`" + else f"CATALOG {catalog_name}" ) if schema_name: @@ -706,7 +706,7 @@ def get_tables( ) assert result is not None, "execute_command returned None in synchronous mode" - # Apply client-side 
filtering by table_types if specified + # Apply client-side filtering by table_types from databricks.sql.backend.filters import ResultSetFilter result = ResultSetFilter.filter_tables_by_type(result, table_types) @@ -728,7 +728,7 @@ def get_columns( if not catalog_name: raise ValueError("Catalog name is required for get_columns") - operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + operation = f"SHOW COLUMNS IN CATALOG {catalog_name}" if schema_name: operation += f" SCHEMA LIKE '{schema_name}'" From 1ab6e8793b04c3065fbe49f9a42d6a3ddb83feed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 04:38:37 +0000 Subject: [PATCH 131/204] line breaks after multi-line PRs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index ec91d87da..2966f6797 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -49,6 +49,7 @@ def _filter_sea_result_set( Returns: A filtered SEA result set """ + # Get all remaining rows all_rows = result_set.results.remaining_rows() @@ -108,6 +109,7 @@ def filter_by_column_values( Returns: A filtered result set """ + # Convert to uppercase for case-insensitive comparison if needed if not case_sensitive: allowed_values = [v.upper() for v in allowed_values] @@ -154,6 +156,7 @@ def filter_tables_by_type( Returns: A filtered result set containing only tables of the specified types """ + # Default table types if none specified DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] valid_types = ( From f469c24c09f82b8d747d4b93b73fdf8380e7c0a5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 04:59:02 +0000 Subject: [PATCH 132/204] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 2966f6797..f8abe26e0 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -9,16 +9,11 @@ List, Optional, Any, - Dict, Callable, - TypeVar, - Generic, cast, TYPE_CHECKING, ) -from databricks.sql.backend.types import ExecuteResponse, CommandId -from databricks.sql.backend.sea.models.base import ResultData from databricks.sql.backend.sea.backend import SeaDatabricksClient if TYPE_CHECKING: From 68ec65f039695d4c98518d676b4ac0d53cf20600 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 05:03:04 +0000 Subject: [PATCH 133/204] fix: introduce ExecuteResponse import Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index f8abe26e0..b97787889 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -15,6 +15,7 @@ ) from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.types import ExecuteResponse if TYPE_CHECKING: from databricks.sql.result_set import ResultSet, SeaResultSet From f6d873dc68b6aa15ea53bdc9c54d6f5d4a7f0106 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 07:58:15 +0000 Subject: [PATCH 134/204] remove unimplemented metadata methods test, un-necessary imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 6 +-- 
.../sql/backend/sea/models/responses.py | 18 +++---- tests/unit/test_filters.py | 5 -- tests/unit/test_sea_backend.py | 53 +------------------ 4 files changed, 13 insertions(+), 69 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a48a97953..9d301d3bc 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -41,9 +41,9 @@ CreateSessionResponse, ) from databricks.sql.backend.sea.models.responses import ( - parse_status, - parse_manifest, - parse_result, + _parse_status, + _parse_manifest, + _parse_result, ) logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index dae37b1ae..0baf27ab2 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -18,7 +18,7 @@ ) -def parse_status(data: Dict[str, Any]) -> StatementStatus: +def _parse_status(data: Dict[str, Any]) -> StatementStatus: """Parse status from response data.""" status_data = data.get("status", {}) error = None @@ -40,7 +40,7 @@ def parse_status(data: Dict[str, Any]) -> StatementStatus: ) -def parse_manifest(data: Dict[str, Any]) -> ResultManifest: +def _parse_manifest(data: Dict[str, Any]) -> ResultManifest: """Parse manifest from response data.""" manifest_data = data.get("manifest", {}) @@ -69,7 +69,7 @@ def parse_manifest(data: Dict[str, Any]) -> ResultManifest: ) -def parse_result(data: Dict[str, Any]) -> ResultData: +def _parse_result(data: Dict[str, Any]) -> ResultData: """Parse result data from response data.""" result_data = data.get("result", {}) external_links = None @@ -118,9 +118,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": """Create an ExecuteStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=parse_status(data), - manifest=parse_manifest(data), - result=parse_result(data), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) @@ -138,9 +138,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": """Create a GetStatementResponse from a dictionary.""" return cls( statement_id=data.get("statement_id", ""), - status=parse_status(data), - manifest=parse_manifest(data), - result=parse_result(data), + status=_parse_status(data), + manifest=_parse_manifest(data), + result=_parse_result(data), ) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 49bd1c328..d0b815b95 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -4,11 +4,6 @@ import unittest from unittest.mock import MagicMock, patch -import sys -from typing import List, Dict, Any - -# Add the necessary path to import the filter module -sys.path.append("/home/varun.edachali/conn/databricks-sql-python/src") from databricks.sql.backend.filters import ResultSetFilter diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f30c92ed0..af4742cb2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -631,55 +631,4 @@ def test_utility_methods(self, sea_client): assert ( sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with 
pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" - ) - - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], - ) - - # Test get_columns - with pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", - ) + \ No newline at end of file From 28675f5c46c5233159d5b0456793ffa9a246d795 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 18 Jun 2025 08:28:27 +0000 Subject: [PATCH 135/204] introduce unit tests for metadata methods Signed-off-by: varun-edachali-dbx --- tests/unit/test_filters.py | 133 +++++++++++------ tests/unit/test_sea_backend.py | 253 ++++++++++++++++++++++++++++++++- 2 files changed, 342 insertions(+), 44 deletions(-) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index d0b815b95..bf8d30707 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -15,17 +15,31 @@ def setUp(self): """Set up test fixtures.""" # Create a mock SeaResultSet self.mock_sea_result_set = MagicMock() - self.mock_sea_result_set._response = { - "result": { - "data_array": [ - ["catalog1", "schema1", "table1", "TABLE", ""], - ["catalog1", "schema1", "table2", "VIEW", ""], - ["catalog1", "schema1", "table3", "SYSTEM TABLE", ""], - ["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""], - ], - "row_count": 4, - } - } + + # Set up the remaining_rows method on the results attribute + self.mock_sea_result_set.results = MagicMock() + self.mock_sea_result_set.results.remaining_rows.return_value = [ + ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""], + ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""], + [ + "catalog1", + "schema1", + "table3", + "owner1", + "2023-01-01", + "SYSTEM TABLE", + "", + ], + [ + "catalog1", + "schema1", + "table4", + "owner1", + "2023-01-01", + "EXTERNAL TABLE", + "", + ], + ] # Set up the connection and other required attributes self.mock_sea_result_set.connection = MagicMock() @@ -33,6 +47,7 @@ def setUp(self): self.mock_sea_result_set.buffer_size_bytes = 1000 self.mock_sea_result_set.arraysize = 100 self.mock_sea_result_set.statement_id = "test-statement-id" + self.mock_sea_result_set.lz4_compressed = False # Create a mock CommandId from databricks.sql.backend.types import CommandId, BackendType @@ -45,70 +60,102 @@ def setUp(self): ("catalog_name", "string", None, None, None, None, True), ("schema_name", "string", None, None, None, None, True), ("table_name", "string", None, None, None, None, True), + ("owner", "string", None, None, None, None, True), + ("creation_time", "string", None, None, None, 
None, True), ("table_type", "string", None, None, None, None, True), ("remarks", "string", None, None, None, None, True), ] self.mock_sea_result_set.has_been_closed_server_side = False + self.mock_sea_result_set._arrow_schema_bytes = None - def test_filter_tables_by_type(self): - """Test filtering tables by type.""" - # Test with specific table types - table_types = ["TABLE", "VIEW"] + def test_filter_by_column_values(self): + """Test filtering by column values with various options.""" + # Case 1: Case-sensitive filtering + allowed_values = ["table1", "table3"] - # Make the mock_sea_result_set appear to be a SeaResultSet with patch("databricks.sql.backend.filters.isinstance", return_value=True): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, table_types + # Call filter_by_column_values on the table_name column (index 2) + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, 2, allowed_values, case_sensitive=True ) # Verify the filter was applied correctly mock_sea_result_set_class.assert_called_once() - def test_filter_tables_by_type_case_insensitive(self): - """Test filtering tables by type with case insensitivity.""" - # Test with lowercase table types - table_types = ["table", "view"] + # Check the filtered data passed to the constructor + args, kwargs = mock_sea_result_set_class.call_args + result_data = kwargs.get("result_data") + self.assertIsNotNone(result_data) + self.assertEqual(len(result_data.data), 2) + self.assertIn(result_data.data[0][2], allowed_values) + self.assertIn(result_data.data[1][2], allowed_values) - # Make the mock_sea_result_set appear to be a SeaResultSet + # Case 2: Case-insensitive filtering + mock_sea_result_set_class.reset_mock() with patch("databricks.sql.backend.filters.isinstance", return_value=True): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, table_types + # Call filter_by_column_values with case-insensitive matching + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, + 2, + ["TABLE1", "TABLE3"], + case_sensitive=False, ) - - # Verify the filter was applied correctly mock_sea_result_set_class.assert_called_once() - def test_filter_tables_by_type_default(self): - """Test filtering tables by type with default types.""" - # Make the mock_sea_result_set appear to be a SeaResultSet - with patch("databricks.sql.backend.filters.isinstance", return_value=True): - with patch( - "databricks.sql.result_set.SeaResultSet" - ) as mock_sea_result_set_class: - # Set up the mock to return a new mock when instantiated - mock_instance = MagicMock() - mock_sea_result_set_class.return_value = mock_instance + # Case 3: Unsupported result set type + mock_unsupported_result_set = MagicMock() + with patch("databricks.sql.backend.filters.isinstance", return_value=False): + with patch("databricks.sql.backend.filters.logger") as mock_logger: + result = ResultSetFilter.filter_by_column_values( + mock_unsupported_result_set, 0, ["value"], True + ) + mock_logger.warning.assert_called_once() + self.assertEqual(result, 
mock_unsupported_result_set) + + def test_filter_tables_by_type(self): + """Test filtering tables by type with various options.""" + # Case 1: Specific table types + table_types = ["TABLE", "VIEW"] - result = ResultSetFilter.filter_tables_by_type( - self.mock_sea_result_set, None + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types ) + args, kwargs = mock_filter.call_args + self.assertEqual(args[0], self.mock_sea_result_set) + self.assertEqual(args[1], 5) # Table type column index + self.assertEqual(args[2], table_types) + self.assertEqual(kwargs.get("case_sensitive"), True) - # Verify the filter was applied correctly - mock_sea_result_set_class.assert_called_once() + # Case 2: Default table types (None or empty list) + with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + # Test with None + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) + + # Test with empty list + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, []) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) if __name__ == "__main__": diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index af4742cb2..d75359f2f 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -631,4 +631,255 @@ def test_utility_methods(self, sea_client): assert ( sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - \ No newline at end of file + + def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): + """Test the get_catalogs method.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call get_catalogs + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify execute_command was called with the correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result is correct + assert result == mock_result_set + + def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): + """Test the get_schemas method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 2: With catalog 
and schema names + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables(self, sea_client, sea_session_id, mock_cursor): + """Test the get_tables method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Mock the filter_tables_by_type method + with patch( + "databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type", + return_value=mock_result_set, + ) as mock_filter: + # Case 1: With catalog name only + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, None) + + # Case 2: With all parameters + table_types = ["TABLE", "VIEW"] + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + table_types=table_types, + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, table_types) + + # Case 3: With wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN ALL CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 4: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns(self, sea_client, sea_session_id, mock_cursor): + """Test the get_columns method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = 
Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 2: With all parameters + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise ValueError) + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_columns" in str(excinfo.value) From 3578659af87df515addf8632d88549df769106d2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 13:56:15 +0530 Subject: [PATCH 136/204] remove verbosity in ResultSetFilter docstring Co-authored-by: jayant <167047871+jayantsing-db@users.noreply.github.com> --- src/databricks/sql/backend/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index b97787889..30f36f25c 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -25,7 +25,7 @@ class ResultSetFilter: """ - A general-purpose filter for result sets that can be applied to any backend. + A general-purpose filter for result sets. This class provides methods to filter result sets based on various criteria, similar to the client-side filtering in the JDBC connector. From 8713023df340c0f943ead5ba7578e6d686953e46 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 08:28:37 +0000 Subject: [PATCH 137/204] remove un-necessary info in ResultSetFilter docstring Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 30f36f25c..17a426596 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -26,9 +26,6 @@ class ResultSetFilter: """ A general-purpose filter for result sets. - - This class provides methods to filter result sets based on various criteria, - similar to the client-side filtering in the JDBC connector. 
""" @staticmethod From 22dc2522f0edfe43d5a7d2398ec487e229491526 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 08:33:39 +0000 Subject: [PATCH 138/204] remove explicit type checking, string literals around forward annotations Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/filters.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py index 17a426596..468fb4d4c 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/filters.py @@ -11,14 +11,12 @@ Any, Callable, cast, - TYPE_CHECKING, ) from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.types import ExecuteResponse -if TYPE_CHECKING: - from databricks.sql.result_set import ResultSet, SeaResultSet +from databricks.sql.result_set import ResultSet, SeaResultSet logger = logging.getLogger(__name__) @@ -30,8 +28,8 @@ class ResultSetFilter: @staticmethod def _filter_sea_result_set( - result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] - ) -> "SeaResultSet": + result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool] + ) -> SeaResultSet: """ Filter a SEA result set using the provided filter function. @@ -49,9 +47,6 @@ def _filter_sea_result_set( # Filter rows filtered_rows = [row for row in all_rows if filter_func(row)] - # Import SeaResultSet here to avoid circular imports - from databricks.sql.result_set import SeaResultSet - # Reuse the command_id from the original result set command_id = result_set.command_id @@ -67,10 +62,13 @@ def _filter_sea_result_set( ) # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData result_data = ResultData(data=filtered_rows, external_links=None) + from databricks.sql.result_set import SeaResultSet + # Create a new SeaResultSet with the filtered data filtered_result_set = SeaResultSet( connection=result_set.connection, @@ -85,11 +83,11 @@ def _filter_sea_result_set( @staticmethod def filter_by_column_values( - result_set: "ResultSet", + result_set: ResultSet, column_index: int, allowed_values: List[str], case_sensitive: bool = False, - ) -> "ResultSet": + ) -> ResultSet: """ Filter a result set by values in a specific column. @@ -133,8 +131,8 @@ def filter_by_column_values( @staticmethod def filter_tables_by_type( - result_set: "ResultSet", table_types: Optional[List[str]] = None - ) -> "ResultSet": + result_set: ResultSet, table_types: Optional[List[str]] = None + ) -> ResultSet: """ Filter a result set of tables by the specified table types. 
From 390f5928aca9b16c5b30b8a7eb292c3b4cd405dd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 08:56:37 +0000 Subject: [PATCH 139/204] house SQL commands in constants Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 27 ++++++++++--------- .../sql/backend/sea/utils/constants.py | 20 ++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 9d301d3bc..ac3644b2f 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -10,6 +10,7 @@ ResultDisposition, ResultCompression, WaitTimeout, + MetadataCommands, ) if TYPE_CHECKING: @@ -635,7 +636,7 @@ def get_catalogs( ) -> "ResultSet": """Get available catalogs by executing 'SHOW CATALOGS'.""" result = self.execute_command( - operation="SHOW CATALOGS", + operation=MetadataCommands.SHOW_CATALOGS.value, session_id=session_id, max_rows=max_rows, max_bytes=max_bytes, @@ -662,10 +663,10 @@ def get_schemas( if not catalog_name: raise ValueError("Catalog name is required for get_schemas") - operation = f"SHOW SCHEMAS IN {catalog_name}" + operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name) if schema_name: - operation += f" LIKE '{schema_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name) result = self.execute_command( operation=operation, @@ -697,17 +698,19 @@ def get_tables( if not catalog_name: raise ValueError("Catalog name is required for get_tables") - operation = "SHOW TABLES IN " + ( - "ALL CATALOGS" + operation = ( + MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] - else f"CATALOG {catalog_name}" + else MetadataCommands.SHOW_TABLES.value.format( + MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name) + ) ) if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) if table_name: - operation += f" LIKE '{table_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(table_name) result = self.execute_command( operation=operation, @@ -745,16 +748,16 @@ def get_columns( if not catalog_name: raise ValueError("Catalog name is required for get_columns") - operation = f"SHOW COLUMNS IN CATALOG {catalog_name}" + operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name) if schema_name: - operation += f" SCHEMA LIKE '{schema_name}'" + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) if table_name: - operation += f" TABLE LIKE '{table_name}'" + operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name) if column_name: - operation += f" LIKE '{column_name}'" + operation += MetadataCommands.LIKE_PATTERN.value.format(column_name) result = self.execute_command( operation=operation, diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 7481a90db..4912455c9 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -45,3 +45,23 @@ class WaitTimeout(Enum): ASYNC = "0s" SYNC = "10s" + + +class MetadataCommands(Enum): + """SQL commands used in the SEA backend. + + These constants are used for metadata operations and other SQL queries + to ensure consistency and avoid string literal duplication. 
+ """ + + SHOW_CATALOGS = "SHOW CATALOGS" + SHOW_SCHEMAS = "SHOW SCHEMAS IN {}" + SHOW_TABLES = "SHOW TABLES IN {}" + SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" + SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" + + SCHEMA_LIKE_PATTERN = " SCHEMA LIKE '{}'" + TABLE_LIKE_PATTERN = " TABLE LIKE '{}'" + LIKE_PATTERN = " LIKE '{}'" + + CATALOG_SPECIFIC = "CATALOG {}" From 2712d1c218bf6577c26e6acbb2a9ddd8b0294203 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 23 Jun 2025 05:19:37 +0000 Subject: [PATCH 140/204] introduce unit tests for altered functionality Signed-off-by: varun-edachali-dbx --- tests/unit/test_json_queue.py | 137 +++++++ tests/unit/test_sea_result_set.py | 348 +++++++++++++++++- .../unit/test_sea_result_set_queue_factory.py | 87 +++++ 3 files changed, 570 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_json_queue.py create mode 100644 tests/unit/test_sea_result_set_queue_factory.py diff --git a/tests/unit/test_json_queue.py b/tests/unit/test_json_queue.py new file mode 100644 index 000000000..ee19a574f --- /dev/null +++ b/tests/unit/test_json_queue.py @@ -0,0 +1,137 @@ +""" +Tests for the JsonQueue class. + +This module contains tests for the JsonQueue class, which implements +a queue for JSON array data returned by the SEA backend. +""" + +import pytest +from databricks.sql.utils import JsonQueue + + +class TestJsonQueue: + """Test suite for the JsonQueue class.""" + + @pytest.fixture + def sample_data_array(self): + """Create a sample data array for testing.""" + return [ + [1, "value1"], + [2, "value2"], + [3, "value3"], + [4, "value4"], + [5, "value5"], + ] + + def test_init(self, sample_data_array): + """Test initializing JsonQueue with a data array.""" + queue = JsonQueue(sample_data_array) + assert queue.data_array == sample_data_array + assert queue.cur_row_index == 0 + assert queue.n_valid_rows == 5 + + def test_next_n_rows_partial(self, sample_data_array): + """Test getting a subset of rows.""" + queue = JsonQueue(sample_data_array) + rows = queue.next_n_rows(3) + + # Check that we got the first 3 rows + assert rows == sample_data_array[:3] + + # Check that the current row index was updated + assert queue.cur_row_index == 3 + + def test_next_n_rows_all(self, sample_data_array): + """Test getting all rows at once.""" + queue = JsonQueue(sample_data_array) + rows = queue.next_n_rows(10) # More than available + + # Check that we got all rows + assert rows == sample_data_array + + # Check that the current row index was updated + assert queue.cur_row_index == 5 + + def test_next_n_rows_empty(self): + """Test getting rows from an empty queue.""" + queue = JsonQueue([]) + rows = queue.next_n_rows(5) + + # Check that we got an empty list + assert rows == [] + + # Check that the current row index was not updated + assert queue.cur_row_index == 0 + + def test_next_n_rows_zero(self, sample_data_array): + """Test getting zero rows.""" + queue = JsonQueue(sample_data_array) + rows = queue.next_n_rows(0) + + # Check that we got an empty list + assert rows == [] + + # Check that the current row index was not updated + assert queue.cur_row_index == 0 + + def test_next_n_rows_sequential(self, sample_data_array): + """Test getting rows in multiple sequential calls.""" + queue = JsonQueue(sample_data_array) + + # Get first 2 rows + rows1 = queue.next_n_rows(2) + assert rows1 == sample_data_array[:2] + assert queue.cur_row_index == 2 + + # Get next 2 rows + rows2 = queue.next_n_rows(2) + assert rows2 == sample_data_array[2:4] + assert 
queue.cur_row_index == 4 + + # Get remaining rows + rows3 = queue.next_n_rows(2) + assert rows3 == sample_data_array[4:] + assert queue.cur_row_index == 5 + + def test_remaining_rows(self, sample_data_array): + """Test getting all remaining rows.""" + queue = JsonQueue(sample_data_array) + + # Get first 2 rows + queue.next_n_rows(2) + + # Get remaining rows + rows = queue.remaining_rows() + + # Check that we got the remaining rows + assert rows == sample_data_array[2:] + + # Check that the current row index was updated to the end + assert queue.cur_row_index == 5 + + def test_remaining_rows_empty(self): + """Test getting remaining rows from an empty queue.""" + queue = JsonQueue([]) + rows = queue.remaining_rows() + + # Check that we got an empty list + assert rows == [] + + # Check that the current row index was not updated + assert queue.cur_row_index == 0 + + def test_remaining_rows_after_all_consumed(self, sample_data_array): + """Test getting remaining rows after all rows have been consumed.""" + queue = JsonQueue(sample_data_array) + + # Consume all rows + queue.next_n_rows(10) + + # Try to get remaining rows + rows = queue.remaining_rows() + + # Check that we got an empty list + assert rows == [] + + # Check that the current row index was not updated + assert queue.cur_row_index == 5 diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f0049e3aa..8c6b9ae3a 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -10,6 +10,8 @@ from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.utils import JsonQueue +from databricks.sql.types import Row class TestSeaResultSet: @@ -20,12 +22,15 @@ def mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True + connection.disable_pandas = False return connection @pytest.fixture def mock_sea_client(self): """Create a mock SEA client.""" - return Mock() + client = Mock() + client.max_download_threads = 10 + return client @pytest.fixture def execute_response(self): @@ -37,11 +42,27 @@ def execute_response(self): mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) + ("col1", "INT", None, None, None, None, None), + ("col2", "STRING", None, None, None, None, None), ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = b"" return mock_response + @pytest.fixture + def mock_result_data(self): + """Create mock result data.""" + result_data = Mock() + result_data.data = [[1, "value1"], [2, "value2"], [3, "value3"]] + result_data.external_links = None + return result_data + + @pytest.fixture + def mock_manifest(self): + """Create a mock manifest.""" + return Mock() + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -63,6 +84,49 @@ def test_init_with_execute_response( assert result_set.arraysize == 100 assert result_set.description == execute_response.description + # Verify that a JsonQueue was created with empty data + assert isinstance(result_set.results, JsonQueue) + assert result_set.results.data_array == [] + + def test_init_with_result_data( + self, + mock_connection, + mock_sea_client, + execute_response, + mock_result_data, + mock_manifest, + ): + """Test initializing SeaResultSet with result data.""" + with patch( + 
"databricks.sql.result_set.SeaResultSetQueueFactory" + ) as mock_factory: + mock_queue = Mock(spec=JsonQueue) + mock_factory.build_queue.return_value = mock_queue + + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + manifest=mock_manifest, + ) + + # Verify that the factory was called with the correct arguments + mock_factory.build_queue.assert_called_once_with( + mock_result_data, + mock_manifest, + str(execute_response.command_id.to_sea_statement_id()), + description=execute_response.description, + max_download_threads=mock_sea_client.max_download_threads, + sea_client=mock_sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Verify that the queue was set correctly + assert result_set.results == mock_queue + def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( @@ -122,3 +186,283 @@ def test_close_when_connection_closed( mock_sea_client.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED + + def test_convert_json_table( + self, mock_connection, mock_sea_client, execute_response + ): + """Test converting JSON data to Row objects.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Sample data + data = [[1, "value1"], [2, "value2"]] + + # Convert to Row objects + rows = result_set._convert_json_table(data) + + # Check that we got Row objects with the correct values + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == 1 + assert rows[0].col2 == "value1" + assert rows[1].col1 == 2 + assert rows[1].col2 == "value2" + + def test_convert_json_table_empty( + self, mock_connection, mock_sea_client, execute_response + ): + """Test converting empty JSON data.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Empty data + data = [] + + # Convert to Row objects + rows = result_set._convert_json_table(data) + + # Check that we got an empty list + assert rows == [] + + def test_convert_json_table_no_description( + self, mock_connection, mock_sea_client, execute_response + ): + """Test converting JSON data with no description.""" + execute_response.description = None + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Sample data + data = [[1, "value1"], [2, "value2"]] + + # Convert to Row objects + rows = result_set._convert_json_table(data) + + # Check that we got the original data + assert rows == data + + def test_fetchone( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test fetching one row.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + + # Fetch one row + row = 
result_set.fetchone() + + # Check that we got a Row object with the correct values + assert isinstance(row, Row) + assert row.col1 == 1 + assert row.col2 == "value1" + + # Check that the row index was updated + assert result_set._next_row_index == 1 + + def test_fetchone_empty(self, mock_connection, mock_sea_client, execute_response): + """Test fetching one row from an empty result set.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Fetch one row + row = result_set.fetchone() + + # Check that we got None + assert row is None + + def test_fetchmany( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test fetching multiple rows.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + + # Fetch two rows + rows = result_set.fetchmany(2) + + # Check that we got two Row objects with the correct values + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == 1 + assert rows[0].col2 == "value1" + assert rows[1].col1 == 2 + assert rows[1].col2 == "value2" + + # Check that the row index was updated + assert result_set._next_row_index == 2 + + def test_fetchmany_negative_size( + self, mock_connection, mock_sea_client, execute_response + ): + """Test fetching with a negative size.""" + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Try to fetch with a negative size + with pytest.raises( + ValueError, match="size argument for fetchmany is -1 but must be >= 0" + ): + result_set.fetchmany(-1) + + def test_fetchall( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test fetching all rows.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + + # Fetch all rows + rows = result_set.fetchall() + + # Check that we got three Row objects with the correct values + assert len(rows) == 3 + assert isinstance(rows[0], Row) + assert rows[0].col1 == 1 + assert rows[0].col2 == "value1" + assert rows[1].col1 == 2 + assert rows[1].col2 == "value2" + assert rows[2].col1 == 3 + assert rows[2].col2 == "value3" + + # Check that the row index was updated + assert result_set._next_row_index == 3 + + def test_fetchmany_json( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test fetching JSON data directly.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, 
"value2"], [3, "value3"]]) + + # Fetch two rows as JSON + rows = result_set.fetchmany_json(2) + + # Check that we got the raw data + assert rows == [[1, "value1"], [2, "value2"]] + + # Check that the row index was updated + assert result_set._next_row_index == 2 + + def test_fetchall_json( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test fetching all JSON data directly.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + + # Fetch all rows as JSON + rows = result_set.fetchall_json() + + # Check that we got the raw data + assert rows == [[1, "value1"], [2, "value2"], [3, "value3"]] + + # Check that the row index was updated + assert result_set._next_row_index == 3 + + def test_iteration( + self, mock_connection, mock_sea_client, execute_response, mock_result_data + ): + """Test iterating over the result set.""" + # Create a result set with data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=mock_result_data, + ) + + # Replace the results queue with a JsonQueue containing test data + result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + + # Iterate over the result set + rows = list(result_set) + + # Check that we got three Row objects with the correct values + assert len(rows) == 3 + assert isinstance(rows[0], Row) + assert rows[0].col1 == 1 + assert rows[0].col2 == "value1" + assert rows[1].col1 == 2 + assert rows[1].col2 == "value2" + assert rows[2].col1 == 3 + assert rows[2].col2 == "value3" + + # Check that the row index was updated + assert result_set._next_row_index == 3 diff --git a/tests/unit/test_sea_result_set_queue_factory.py b/tests/unit/test_sea_result_set_queue_factory.py new file mode 100644 index 000000000..f72510afb --- /dev/null +++ b/tests/unit/test_sea_result_set_queue_factory.py @@ -0,0 +1,87 @@ +""" +Tests for the SeaResultSetQueueFactory class. + +This module contains tests for the SeaResultSetQueueFactory class, which builds +appropriate result set queues for the SEA backend. 
+""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.utils import SeaResultSetQueueFactory, JsonQueue +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + + +class TestSeaResultSetQueueFactory: + """Test suite for the SeaResultSetQueueFactory class.""" + + @pytest.fixture + def mock_result_data_with_json(self): + """Create a mock ResultData with JSON data.""" + result_data = Mock(spec=ResultData) + result_data.data = [[1, "value1"], [2, "value2"]] + result_data.external_links = None + return result_data + + @pytest.fixture + def mock_result_data_with_external_links(self): + """Create a mock ResultData with external links.""" + result_data = Mock(spec=ResultData) + result_data.data = None + result_data.external_links = ["link1", "link2"] + return result_data + + @pytest.fixture + def mock_result_data_empty(self): + """Create a mock ResultData with no data.""" + result_data = Mock(spec=ResultData) + result_data.data = None + result_data.external_links = None + return result_data + + @pytest.fixture + def mock_manifest(self): + """Create a mock manifest.""" + return Mock(spec=ResultManifest) + + def test_build_queue_with_json_data( + self, mock_result_data_with_json, mock_manifest + ): + """Test building a queue with JSON data.""" + queue = SeaResultSetQueueFactory.build_queue( + sea_result_data=mock_result_data_with_json, + manifest=mock_manifest, + statement_id="test-statement-id", + ) + + # Check that we got a JsonQueue + assert isinstance(queue, JsonQueue) + + # Check that the queue has the correct data + assert queue.data_array == mock_result_data_with_json.data + + def test_build_queue_with_external_links( + self, mock_result_data_with_external_links, mock_manifest + ): + """Test building a queue with external links.""" + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSetQueueFactory.build_queue( + sea_result_data=mock_result_data_with_external_links, + manifest=mock_manifest, + statement_id="test-statement-id", + ) + + def test_build_queue_with_empty_data(self, mock_result_data_empty, mock_manifest): + """Test building a queue with empty data.""" + queue = SeaResultSetQueueFactory.build_queue( + sea_result_data=mock_result_data_empty, + manifest=mock_manifest, + statement_id="test-statement-id", + ) + + # Check that we got a JsonQueue with empty data + assert isinstance(queue, JsonQueue) + assert queue.data_array == [] From 984e8eedfff3997f68a30bf5bcd4b75de2051c15 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 23 Jun 2025 05:47:22 +0000 Subject: [PATCH 141/204] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 49394b12a..c67e9b3f2 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING import logging -import time import pandas from databricks.sql.backend.sea.backend import SeaDatabricksClient @@ -17,9 +16,8 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.thrift_api.TCLIService import 
ttypes from databricks.sql.types import Row -from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError +from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.utils import ( ColumnTable, ColumnQueue, From 0ce144d2e81f173b11bbece6d0d5fa1ba8b9806d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 23 Jun 2025 06:40:58 +0000 Subject: [PATCH 142/204] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ac3644b2f..9fa425f34 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -41,11 +41,6 @@ GetStatementResponse, CreateSessionResponse, ) -from databricks.sql.backend.sea.models.responses import ( - _parse_status, - _parse_manifest, - _parse_result, -) logger = logging.getLogger(__name__) From 50cc1e2315cf2b5cb154b666d04b482deb6e9d8c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 03:29:24 +0000 Subject: [PATCH 143/204] run small queries with SEA during integration tests Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 144 ++++++++++++++++++------ tests/e2e/test_parameterized_queries.py | 70 +++++++++--- 2 files changed, 162 insertions(+), 52 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 8cfed7c28..dc3280263 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -196,10 +196,14 @@ def test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} - def test_execute_async__small_result(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute_async(small_result_query) ## Fake sleep for 5 secs @@ -328,8 +332,12 @@ def test_incorrect_query_throws_exception(self): cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") assert "table_234234234" in str(cm.value) - def test_create_table_will_return_empty_result_set(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_create_table_will_return_empty_result_set(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute( @@ -341,8 +349,12 @@ def test_create_table_will_return_empty_result_set(self): finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) - def test_get_tables(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_get_tables(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -387,8 +399,12 @@ def test_get_tables(self): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - def test_get_columns(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_get_columns(self, extra_params): + with 
self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -474,8 +490,12 @@ def test_get_columns(self): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - def test_escape_single_quotes(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_escape_single_quotes(self, extra_params): + with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly cursor.execute( @@ -499,8 +519,12 @@ def test_escape_single_quotes(self): rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" - def test_get_schemas(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_get_schemas(self, extra_params): + with self.cursor(extra_params) as cursor: database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute("CREATE DATABASE IF NOT EXISTS {}".format(database_name)) @@ -517,8 +541,12 @@ def test_get_schemas(self): finally: cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) - def test_get_catalogs(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_get_catalogs(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description @@ -527,10 +555,14 @@ def test_get_catalogs(self): ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") - def test_get_arrow(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_get_arrow(self, extra_params): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM range(10)") table_1 = cursor.fetchmany_arrow(1).to_pydict() assert table_1 == OrderedDict([("id", [0])]) @@ -538,16 +570,24 @@ def test_get_arrow(self): table_2 = cursor.fetchall_arrow().to_pydict() assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) - def test_unicode(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_unicode(self, extra_params): unicode_str = "数据砖" - with self.cursor({}) as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT '{}'".format(unicode_str)) results = cursor.fetchall() assert len(results) == 1 and len(results[0]) == 1 assert results[0][0] == unicode_str - def test_cancel_during_execute(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_cancel_during_execute(self, extra_params): + with self.cursor(extra_params) as cursor: def execute_really_long_query(): cursor.execute( @@ -578,8 +618,12 @@ def execute_really_long_query(): assert len(cursor.fetchall()) == 3 @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_failure(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def 
test_can_execute_command_after_failure(self, extra_params): + with self.cursor(extra_params) as cursor: with pytest.raises(DatabaseError): cursor.execute("this is a sytnax error") @@ -589,8 +633,12 @@ def test_can_execute_command_after_failure(self): self.assertEqualRowValues(res, [[1]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_can_execute_command_after_success(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_can_execute_command_after_success(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1;") cursor.execute("SELECT 2;") @@ -602,8 +650,12 @@ def generate_multi_row_query(self): return query @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchone(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_fetchone(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -614,8 +666,12 @@ def test_fetchone(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchall(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_fetchall(self, extra_params): + with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() cursor.execute(query) @@ -624,8 +680,12 @@ def test_fetchall(self): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_when_stride_fits(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_fetchmany_when_stride_fits(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -633,8 +693,12 @@ def test_fetchmany_when_stride_fits(self): self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_fetchmany_in_excess(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_fetchmany_in_excess(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -642,8 +706,12 @@ def test_fetchmany_in_excess(self): self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - def test_iterator_api(self): - with self.cursor({}) as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_iterator_api(self, extra_params): + with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) @@ -803,8 +871,12 @@ def test_decimal_not_returned_as_strings_arrow(self): assert pyarrow.types.is_decimal(decimal_type) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_catalogs_returns_arrow_table(self): - with self.cursor() as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_catalogs_returns_arrow_table(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.catalogs() results = 
cursor.fetchall_arrow() assert isinstance(results, pyarrow.Table) diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 79def9b72..d7afa8ae5 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -404,9 +404,13 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) "Consider using native parameters." not in caplog.text ), "Log message should not be supressed" - def test_positional_native_params_with_defaults(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_positional_native_params_with_defaults(self, extra_params): query = "SELECT ? col" - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: result = cursor.execute(query, parameters=[1]).fetchone() assert result.col == 1 @@ -422,10 +426,15 @@ def test_positional_native_params_with_defaults(self): ["foo", "bar", "baz"], ), ) - def test_positional_native_multiple(self, params): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_positional_native_multiple(self, params, extra_params): query = "SELECT ? `foo`, ? `bar`, ? `baz`" - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + combined_params = {"use_inline_params": False, **extra_params} + with self.cursor(extra_params=combined_params) as cursor: result = cursor.execute(query, params).fetchone() expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params] @@ -433,8 +442,12 @@ def test_positional_native_multiple(self, params): assert set(outcome) == set(expected) - def test_readme_example(self): - with self.cursor() as cursor: + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_readme_example(self, extra_params): + with self.cursor(extra_params) as cursor: result = cursor.execute( "SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"} ).fetchall() @@ -498,11 +511,16 @@ def test_native_recursive_complex_type( class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" - def test_params_as_dict(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_params_as_dict(self, extra_params): query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz" params = {"foo": 1, "bar": 2, "baz": 3} - with self.connection(extra_params={"use_inline_params": True}) as conn: + combined_params = {"use_inline_params": True, **extra_params} + with self.connection(extra_params=combined_params) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() @@ -510,7 +528,11 @@ def test_params_as_dict(self): assert result.bar == 2 assert result.baz == 3 - def test_params_as_sequence(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_params_as_sequence(self, extra_params): """One side-effect of ParamEscaper using Python string interpolation to inline the values is that it can work with "ordinal" parameters, but only if a user writes parameter markers that are not defined with PEP-249. This test exists to prove that it works in the ideal case. 
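For context on the two parameter styles these e2e tests toggle between, a minimal sketch (not part of the patch; credentials come from the usual env vars, and the kwargs mirror the extra_params passed in the tests): inline mode interpolates values into the SQL text client-side, so pyformat markers work, while the default native mode sends parameters separately and expects the `?` marker.

    # Sketch only, assuming the standard DATABRICKS_* env vars used elsewhere in these tests.
    import os
    from databricks import sql

    conn_args = dict(
        server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
        http_path=os.environ["DATABRICKS_HTTP_PATH"],
        access_token=os.environ["DATABRICKS_TOKEN"],
    )

    # Inline: pyformat markers, values interpolated into the statement client-side.
    with sql.connect(use_inline_params=True, **conn_args) as conn:
        with conn.cursor() as cursor:
            assert cursor.execute("SELECT %(x)s AS x", {"x": 1}).fetchone().x == 1

    # Native (default): qmark-style markers, values sent alongside the statement.
    with sql.connect(**conn_args) as conn:
        with conn.cursor() as cursor:
            assert cursor.execute("SELECT ? AS x", [1]).fetchone().x == 1
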
@@ -520,7 +542,8 @@ def test_params_as_sequence(self): query = "SELECT %s foo, %s bar, %s baz" params = (1, 2, 3) - with self.connection(extra_params={"use_inline_params": True}) as conn: + combined_params = {"use_inline_params": True, **extra_params} + with self.connection(extra_params=combined_params) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.foo == 1 @@ -540,7 +563,11 @@ def test_inline_ordinals_can_break_sql(self): ): cursor.execute(query, parameters=params) - def test_inline_named_dont_break_sql(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_inline_named_dont_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. @@ -550,17 +577,23 @@ def test_inline_named_dont_break_sql(self): SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite') """ params = {"one": "%(one)s"} - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + combined_params = {"use_inline_params": True, **extra_params} + with self.cursor(extra_params=combined_params) as cursor: result = cursor.execute(query, parameters=params).fetchone() print("hello") - def test_native_ordinals_dont_break_sql(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_native_ordinals_dont_break_sql(self, extra_params): """This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal parameters work in native mode for the exact same query, if we use the right marker `?` """ query = "SELECT 'samsonite', ? WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + combined_params = {"use_inline_params": False, **extra_params} + with self.cursor(extra_params=combined_params) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.samsonite == "samsonite" @@ -576,13 +609,18 @@ def test_inline_like_wildcard_breaks(self): with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() - def test_native_like_wildcard_works(self): + @pytest.mark.parametrize("extra_params", [ + {}, + {"use_sea": True, "use_cloud_fetch": False} + ]) + def test_native_like_wildcard_works(self, extra_params): """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. 
""" query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + combined_params = {"use_inline_params": False, **extra_params} + with self.cursor(extra_params=combined_params) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.col == 1 From 242307aa656d390ea41c54eddf871417a0a0458b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 24 Jun 2025 10:59:15 +0000 Subject: [PATCH 144/204] run some tests for sea Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 12 +- tests/e2e/common/large_queries_mixin.py | 2 +- tests/e2e/test_driver.py | 259 ++++++++++++++++------ tests/e2e/test_parameterized_queries.py | 121 +++++++--- 4 files changed, 282 insertions(+), 112 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 9fa425f34..fa83c669e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -12,6 +12,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -403,7 +404,7 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[Dict[str, Any]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: @@ -437,9 +438,9 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param["name"], - value=param["value"], - type=param["type"] if "type" in param else None, + name=param.name, + value=param.value.stringValue, + type=param.type, ) ) @@ -690,9 +691,6 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - operation = ( MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index ed8ac4574..f57377da4 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -87,7 +87,7 @@ def test_long_running_query(self): and asserts that the query completes successfully. 
""" minutes = 60 - min_duration = 5 * minutes + min_duration = 3 * minutes duration = -1 scale0 = 10000 diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index dc3280263..9d297f6ab 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -196,10 +196,6 @@ def test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" @@ -332,10 +328,17 @@ def test_incorrect_query_throws_exception(self): cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") assert "table_234234234" in str(cm.value) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_create_table_will_return_empty_result_set(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -349,10 +352,17 @@ def test_create_table_will_return_empty_result_set(self, extra_params): finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_get_tables(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -399,10 +409,17 @@ def test_get_tables(self, extra_params): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_get_columns(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -490,10 +507,17 @@ def test_get_columns(self, extra_params): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_escape_single_quotes(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -519,10 +543,17 @@ def test_escape_single_quotes(self, extra_params): rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_get_schemas(self, extra_params): with self.cursor(extra_params) as cursor: database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -541,10 +572,17 @@ def 
test_get_schemas(self, extra_params): finally: cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_get_catalogs(self, extra_params): with self.cursor(extra_params) as cursor: cursor.catalogs() @@ -555,10 +593,17 @@ def test_get_catalogs(self, extra_params): ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_get_arrow(self, extra_params): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else @@ -570,10 +615,17 @@ def test_get_arrow(self, extra_params): table_2 = cursor.fetchall_arrow().to_pydict() assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_unicode(self, extra_params): unicode_str = "数据砖" with self.cursor(extra_params) as cursor: @@ -582,10 +634,17 @@ def test_unicode(self, extra_params): assert len(results) == 1 and len(results[0]) == 1 assert results[0][0] == unicode_str - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_cancel_during_execute(self, extra_params): with self.cursor(extra_params) as cursor: @@ -618,10 +677,17 @@ def execute_really_long_query(): assert len(cursor.fetchall()) == 3 @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_can_execute_command_after_failure(self, extra_params): with self.cursor(extra_params) as cursor: with pytest.raises(DatabaseError): @@ -633,10 +699,17 @@ def test_can_execute_command_after_failure(self, extra_params): self.assertEqualRowValues(res, [[1]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_can_execute_command_after_success(self, extra_params): with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1;") @@ -650,10 +723,17 @@ def generate_multi_row_query(self): return query @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": 
True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_fetchone(self, extra_params): with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() @@ -666,10 +746,17 @@ def test_fetchone(self, extra_params): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_fetchall(self, extra_params): with self.cursor(extra_params) as cursor: query = self.generate_multi_row_query() @@ -680,10 +767,17 @@ def test_fetchall(self, extra_params): assert cursor.fetchone() == None @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_fetchmany_when_stride_fits(self, extra_params): with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" @@ -693,10 +787,17 @@ def test_fetchmany_when_stride_fits(self, extra_params): self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_fetchmany_in_excess(self, extra_params): with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" @@ -706,10 +807,17 @@ def test_fetchmany_in_excess(self, extra_params): self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) @skipIf(pysql_has_version("<", "2"), "requires pysql v2") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_iterator_api(self, extra_params): with self.cursor(extra_params) as cursor: query = "SELECT * FROM range(4)" @@ -871,10 +979,17 @@ def test_decimal_not_returned_as_strings_arrow(self): assert pyarrow.types.is_decimal(decimal_type) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_catalogs_returns_arrow_table(self, extra_params): with self.cursor(extra_params) as cursor: cursor.catalogs() diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index d7afa8ae5..e696c667b 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from decimal import Decimal from enum import Enum +import json from typing import Dict, List, Type, Union from unittest.mock import patch @@ -404,10 +405,17 @@ def test_use_inline_off_by_default_with_warning(self, 
use_inline_params, caplog) "Consider using native parameters." not in caplog.text ), "Log message should not be supressed" - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_positional_native_params_with_defaults(self, extra_params): query = "SELECT ? col" with self.cursor(extra_params) as cursor: @@ -426,10 +434,17 @@ def test_positional_native_params_with_defaults(self, extra_params): ["foo", "bar", "baz"], ), ) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_positional_native_multiple(self, params, extra_params): query = "SELECT ? `foo`, ? `bar`, ? `baz`" @@ -442,10 +457,17 @@ def test_positional_native_multiple(self, params, extra_params): assert set(outcome) == set(expected) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_readme_example(self, extra_params): with self.cursor(extra_params) as cursor: result = cursor.execute( @@ -511,10 +533,17 @@ def test_native_recursive_complex_type( class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_params_as_dict(self, extra_params): query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz" params = {"foo": 1, "bar": 2, "baz": 3} @@ -528,10 +557,17 @@ def test_params_as_dict(self, extra_params): assert result.bar == 2 assert result.baz == 3 - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_params_as_sequence(self, extra_params): """One side-effect of ParamEscaper using Python string interpolation to inline the values is that it can work with "ordinal" parameters, but only if a user writes parameter markers @@ -563,10 +599,17 @@ def test_inline_ordinals_can_break_sql(self): ): cursor.execute(query, parameters=params) - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_inline_named_dont_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. 
This test @@ -582,10 +625,17 @@ def test_inline_named_dont_break_sql(self, extra_params): result = cursor.execute(query, parameters=params).fetchone() print("hello") - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_native_ordinals_dont_break_sql(self, extra_params): """This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal parameters work in native mode for the exact same query, if we use the right marker `?` @@ -609,10 +659,17 @@ def test_inline_like_wildcard_breaks(self): with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() - @pytest.mark.parametrize("extra_params", [ - {}, - {"use_sea": True, "use_cloud_fetch": False} - ]) + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_native_like_wildcard_works(self, extra_params): """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. From 35f1ef0eb40928d4c92b4b69312acf603c95dcd8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 01:56:46 +0000 Subject: [PATCH 145/204] remove catalog requirement in get_tables Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 3 --- src/databricks/sql/backend/sea/utils/constants.py | 4 ++-- tests/unit/test_sea_backend.py | 10 ---------- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ac3644b2f..53679d10e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -695,9 +695,6 @@ def get_tables( table_types: Optional[List[str]] = None, ) -> "ResultSet": """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" - if not catalog_name: - raise ValueError("Catalog name is required for get_tables") - operation = ( MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value if catalog_name in [None, "*", "%"] diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 4912455c9..402da0de5 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -60,8 +60,8 @@ class MetadataCommands(Enum): SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" - SCHEMA_LIKE_PATTERN = " SCHEMA LIKE '{}'" - TABLE_LIKE_PATTERN = " TABLE LIKE '{}'" LIKE_PATTERN = " LIKE '{}'" + SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN + TABLE_LIKE_PATTERN = " TABLE" + LIKE_PATTERN CATALOG_SPECIFIC = "CATALOG {}" diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index d75359f2f..e6c8734d0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -810,16 +810,6 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): enforce_embedded_schema_correctness=False, ) - # Case 4: Without catalog name (should raise ValueError) - with pytest.raises(ValueError) as excinfo: - sea_client.get_tables( - session_id=sea_session_id, - max_rows=100, - max_bytes=1000, - cursor=mock_cursor, - ) - 
assert "Catalog name is required for get_tables" in str(excinfo.value) - def test_get_columns(self, sea_client, sea_session_id, mock_cursor): """Test the get_columns method with various parameter combinations.""" # Mock the execute_command method From a515d260992b7902b017daf152b1c04c86c3d46d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 05:37:46 +0000 Subject: [PATCH 146/204] move filters.py to SEA utils Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- .../sql/backend/{ => sea/utils}/filters.py | 42 +++++++------------ tests/unit/test_filters.py | 28 ++++++------- tests/unit/test_sea_backend.py | 2 +- 4 files changed, 31 insertions(+), 43 deletions(-) rename src/databricks/sql/backend/{ => sea/utils}/filters.py (80%) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 53679d10e..e6d9a082e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -724,7 +724,7 @@ def get_tables( assert result is not None, "execute_command returned None in synchronous mode" # Apply client-side filtering by table_types - from databricks.sql.backend.filters import ResultSetFilter + from databricks.sql.backend.sea.utils.filters import ResultSetFilter result = ResultSetFilter.filter_tables_by_type(result, table_types) diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/sea/utils/filters.py similarity index 80% rename from src/databricks/sql/backend/filters.py rename to src/databricks/sql/backend/sea/utils/filters.py index 468fb4d4c..493975433 100644 --- a/src/databricks/sql/backend/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -83,11 +83,11 @@ def _filter_sea_result_set( @staticmethod def filter_by_column_values( - result_set: ResultSet, + result_set: SeaResultSet, column_index: int, allowed_values: List[str], case_sensitive: bool = False, - ) -> ResultSet: + ) -> SeaResultSet: """ Filter a result set by values in a specific column. @@ -105,34 +105,24 @@ def filter_by_column_values( if not case_sensitive: allowed_values = [v.upper() for v in allowed_values] - # Determine the type of result set and apply appropriate filtering - from databricks.sql.result_set import SeaResultSet - - if isinstance(result_set, SeaResultSet): - return ResultSetFilter._filter_sea_result_set( - result_set, - lambda row: ( - len(row) > column_index - and isinstance(row[column_index], str) - and ( - row[column_index].upper() - if not case_sensitive - else row[column_index] - ) - in allowed_values - ), - ) - - # For other result set types, return the original (should be handled by specific implementations) - logger.warning( - f"Filtering not implemented for result set type: {type(result_set).__name__}" + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), ) - return result_set @staticmethod def filter_tables_by_type( - result_set: ResultSet, table_types: Optional[List[str]] = None - ) -> ResultSet: + result_set: SeaResultSet, table_types: Optional[List[str]] = None + ) -> SeaResultSet: """ Filter a result set of tables by the specified table types. 
diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index bf8d30707..975376e13 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -5,7 +5,7 @@ import unittest from unittest.mock import MagicMock, patch -from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.backend.sea.utils.filters import ResultSetFilter class TestResultSetFilter(unittest.TestCase): @@ -73,7 +73,9 @@ def test_filter_by_column_values(self): # Case 1: Case-sensitive filtering allowed_values = ["table1", "table3"] - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: @@ -98,7 +100,9 @@ def test_filter_by_column_values(self): # Case 2: Case-insensitive filtering mock_sea_result_set_class.reset_mock() - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch( "databricks.sql.result_set.SeaResultSet" ) as mock_sea_result_set_class: @@ -114,22 +118,14 @@ def test_filter_by_column_values(self): ) mock_sea_result_set_class.assert_called_once() - # Case 3: Unsupported result set type - mock_unsupported_result_set = MagicMock() - with patch("databricks.sql.backend.filters.isinstance", return_value=False): - with patch("databricks.sql.backend.filters.logger") as mock_logger: - result = ResultSetFilter.filter_by_column_values( - mock_unsupported_result_set, 0, ["value"], True - ) - mock_logger.warning.assert_called_once() - self.assertEqual(result, mock_unsupported_result_set) - def test_filter_tables_by_type(self): """Test filtering tables by type with various options.""" # Case 1: Specific table types table_types = ["TABLE", "VIEW"] - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch.object( ResultSetFilter, "filter_by_column_values" ) as mock_filter: @@ -143,7 +139,9 @@ def test_filter_tables_by_type(self): self.assertEqual(kwargs.get("case_sensitive"), True) # Case 2: Default table types (None or empty list) - with patch("databricks.sql.backend.filters.isinstance", return_value=True): + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): with patch.object( ResultSetFilter, "filter_by_column_values" ) as mock_filter: diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index e6c8734d0..2d45a1f49 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -735,7 +735,7 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor): ) as mock_execute: # Mock the filter_tables_by_type method with patch( - "databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type", + "databricks.sql.backend.sea.utils.filters.ResultSetFilter.filter_tables_by_type", return_value=mock_result_set, ) as mock_filter: # Case 1: With catalog name only From 59b1330f2db8e680bce7b17b0941e39699b93cf2 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 05:40:23 +0000 Subject: [PATCH 147/204] ensure SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/databricks/sql/backend/sea/backend.py 
b/src/databricks/sql/backend/sea/backend.py index e6d9a082e..623979115 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -12,6 +12,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.result_set import SeaResultSet if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -722,6 +723,9 @@ def get_tables( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" + assert isinstance( + result, SeaResultSet + ), "SEA backend execute_command returned a non-SeaResultSet" # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter From dd40bebff73442eedfd264192dc05376a7f86bed Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:13:43 +0000 Subject: [PATCH 148/204] prevent circular imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 +++----- src/databricks/sql/backend/sea/utils/filters.py | 11 +++++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 623979115..2af77ec45 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,7 +1,7 @@ import logging import time import re -from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set, cast from databricks.sql.backend.sea.models.base import ResultManifest from databricks.sql.backend.sea.utils.constants import ( @@ -12,7 +12,6 @@ WaitTimeout, MetadataCommands, ) -from databricks.sql.result_set import SeaResultSet if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -723,13 +722,12 @@ def get_tables( enforce_embedded_schema_correctness=False, ) assert result is not None, "execute_command returned None in synchronous mode" - assert isinstance( - result, SeaResultSet - ), "SEA backend execute_command returned a non-SeaResultSet" # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter + from databricks.sql.result_set import SeaResultSet + result = cast(SeaResultSet, result) result = ResultSetFilter.filter_tables_by_type(result, table_types) return result diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 493975433..db6a12e16 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -4,6 +4,8 @@ This module provides filtering capabilities for result sets returned by different backends. 
""" +from __future__ import annotations + import logging from typing import ( List, @@ -11,12 +13,13 @@ Any, Callable, cast, + TYPE_CHECKING, ) -from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import ExecuteResponse +if TYPE_CHECKING: + from databricks.sql.result_set import SeaResultSet -from databricks.sql.result_set import ResultSet, SeaResultSet +from databricks.sql.backend.types import ExecuteResponse logger = logging.getLogger(__name__) @@ -62,11 +65,11 @@ def _filter_sea_result_set( ) # Create a new ResultData object with filtered data - from databricks.sql.backend.sea.models.base import ResultData result_data = ResultData(data=filtered_rows, external_links=None) + from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data From 14057acb8e3201574b8a2054eb63506d7d894800 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:46:16 +0000 Subject: [PATCH 149/204] remove unused imports Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2af77ec45..b5385d5df 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -41,11 +41,6 @@ GetStatementResponse, CreateSessionResponse, ) -from databricks.sql.backend.sea.models.responses import ( - _parse_status, - _parse_manifest, - _parse_result, -) logger = logging.getLogger(__name__) From a4d5bdb726aee53bfa27b60e1b7baf78c01a67d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 06:51:59 +0000 Subject: [PATCH 150/204] remove cast, throw error if not SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 10 +++++++--- tests/unit/test_sea_backend.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index b5385d5df..2cd1c98c2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,7 +1,7 @@ import logging import time import re -from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set, cast +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.models.base import ResultManifest from databricks.sql.backend.sea.utils.constants import ( @@ -718,11 +718,15 @@ def get_tables( ) assert result is not None, "execute_command returned None in synchronous mode" + from databricks.sql.result_set import SeaResultSet + + assert isinstance( + result, SeaResultSet + ), "execute_command returned a non-SeaResultSet" + # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter - from databricks.sql.result_set import SeaResultSet - result = cast(SeaResultSet, result) result = ResultSetFilter.filter_tables_by_type(result, table_types) return result diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 2d45a1f49..68dea3d81 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -729,7 +729,10 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): def test_get_tables(self, sea_client, sea_session_id, mock_cursor): """Test the get_tables method with 
various parameter combinations.""" # Mock the execute_command method - mock_result_set = Mock() + from databricks.sql.result_set import SeaResultSet + + mock_result_set = Mock(spec=SeaResultSet) + with patch.object( sea_client, "execute_command", return_value=mock_result_set ) as mock_execute: From eb1a9b44f88d14558ef2890d841e9eb196f94bc7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 07:09:30 +0000 Subject: [PATCH 151/204] pass param as TSparkParameterValue Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 68dea3d81..ff5ae3976 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -354,7 +354,11 @@ def test_command_execution_advanced( "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - param = {"name": "param1", "value": "value1", "type": "STRING"} + param = Mock() + param.name = "param1" + param.value = Mock() + param.value.stringValue = "value1" + param.type = "STRING" with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( From e9b1314e28c2898f4d9c32defcf7042d4eb1fada Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 09:33:54 +0000 Subject: [PATCH 152/204] make SEA backend methods return SeaResultSet Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2cd1c98c2..83255f79b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import time import re @@ -15,7 +17,7 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet + from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -401,12 +403,12 @@ def execute_command( max_rows: int, max_bytes: int, lz4_compression: bool, - cursor: "Cursor", + cursor: Cursor, use_cloud_fetch: bool, parameters: List[Dict[str, Any]], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: + ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. @@ -573,8 +575,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState: def get_execution_result( self, command_id: CommandId, - cursor: "Cursor", - ) -> "ResultSet": + cursor: Cursor, + ) -> SeaResultSet: """ Get the result of a command execution. 
@@ -583,7 +585,7 @@ def get_execution_result( cursor: Cursor executing the command Returns: - ResultSet: A SeaResultSet instance with the execution results + SeaResultSet: A SeaResultSet instance with the execution results Raises: ValueError: If the command ID is invalid @@ -627,8 +629,8 @@ def get_catalogs( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", - ) -> "ResultSet": + cursor: Cursor, + ) -> SeaResultSet: """Get available catalogs by executing 'SHOW CATALOGS'.""" result = self.execute_command( operation=MetadataCommands.SHOW_CATALOGS.value, @@ -650,10 +652,10 @@ def get_schemas( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ) -> "ResultSet": + ) -> SeaResultSet: """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" if not catalog_name: raise ValueError("Catalog name is required for get_schemas") @@ -683,12 +685,12 @@ def get_tables( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ) -> "ResultSet": + ) -> SeaResultSet: """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" operation = ( MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value @@ -718,12 +720,6 @@ def get_tables( ) assert result is not None, "execute_command returned None in synchronous mode" - from databricks.sql.result_set import SeaResultSet - - assert isinstance( - result, SeaResultSet - ), "execute_command returned a non-SeaResultSet" - # Apply client-side filtering by table_types from databricks.sql.backend.sea.utils.filters import ResultSetFilter @@ -736,12 +732,12 @@ def get_columns( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ) -> "ResultSet": + ) -> SeaResultSet: """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" if not catalog_name: raise ValueError("Catalog name is required for get_columns") From 8ede414f8ac485f4e9ed83b49af7087b106d0175 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 09:38:33 +0000 Subject: [PATCH 153/204] use spec-aligned Exceptions in SEA backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 26 +++++++++++------------ tests/unit/test_sea_backend.py | 17 ++++++++------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 83255f79b..bfc0c6c9e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -27,7 +27,7 @@ BackendType, ExecuteResponse, ) -from databricks.sql.exc import DatabaseError, ServerOperationError +from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions @@ -172,7 +172,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: f"Note: SEA only works for warehouses." 
) logger.error(error_message) - raise ValueError(error_message) + raise ProgrammingError(error_message) @property def max_download_threads(self) -> int: @@ -244,14 +244,14 @@ def close_session(self, session_id: SessionId) -> None: session_id: The session identifier returned by open_session() Raises: - ValueError: If the session ID is invalid + ProgrammingError: If the session ID is invalid OperationalError: If there's an error closing the session """ logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") + raise ProgrammingError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() request_data = DeleteSessionRequest( @@ -429,7 +429,7 @@ def execute_command( """ if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") + raise ProgrammingError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -504,11 +504,11 @@ def cancel_command(self, command_id: CommandId) -> None: command_id: Command identifier to cancel Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -527,11 +527,11 @@ def close_command(self, command_id: CommandId) -> None: command_id: Command identifier to close Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -553,7 +553,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: CommandState: The current state of the command Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: @@ -592,7 +592,7 @@ def get_execution_result( """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -658,7 +658,7 @@ def get_schemas( ) -> SeaResultSet: """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" if not catalog_name: - raise ValueError("Catalog name is required for get_schemas") + raise DatabaseError("Catalog name is required for get_schemas") operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name) @@ -740,7 +740,7 @@ def get_columns( ) -> SeaResultSet: """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" if not catalog_name: - raise ValueError("Catalog name is required for get_columns") + raise DatabaseError("Catalog name is required for get_columns") operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 68dea3d81..6847cded0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -18,6 +18,7 @@ from databricks.sql.exc import ( Error, NotSupportedError, + ProgrammingError, ServerOperationError, DatabaseError, ) @@ -129,7 +130,7 @@ def test_initialization(self, mock_http_client): assert 
client3.max_download_threads == 5 # Test with invalid HTTP path - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", port=443, @@ -195,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i ) # Test close_session with invalid ID type - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) @@ -244,7 +245,7 @@ def test_command_execution_sync( assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: mock_thrift_handle = MagicMock() mock_thrift_handle.sessionId.guid = b"guid" mock_thrift_handle.sessionId.secret = b"secret" @@ -448,7 +449,7 @@ def test_command_management( ) # Test cancel_command with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.cancel_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -462,7 +463,7 @@ def test_command_management( ) # Test close_command with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.close_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -521,7 +522,7 @@ def test_command_management( assert result.status == CommandState.SUCCEEDED # Test get_execution_result with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -717,7 +718,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): ) # Case 3: Without catalog name (should raise ValueError) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(DatabaseError) as excinfo: sea_client.get_schemas( session_id=sea_session_id, max_rows=100, @@ -868,7 +869,7 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): ) # Case 3: Without catalog name (should raise ValueError) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(DatabaseError) as excinfo: sea_client.get_columns( session_id=sea_session_id, max_rows=100, From 09a1b11865ef9bad7d0ae5e510aede2b375f1beb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Thu, 26 Jun 2025 09:51:38 +0000 Subject: [PATCH 154/204] remove defensive row type check Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/filters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index db6a12e16..1b7660829 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -112,7 +112,6 @@ def filter_by_column_values( result_set, lambda row: ( len(row) > column_index - and isinstance(row[column_index], str) and ( row[column_index].upper() if not case_sensitive From 21c389d5f3fa5e61d475ddc2a11a78838e21288a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 03:27:22 +0000 Subject: [PATCH 155/204] introduce type conversion for primitive types for JSON + INLINE Signed-off-by: varun-edachali-dbx --- src/databricks/sql/conversion.py | 201 
+++++++++++++++++++++++++++++ src/databricks/sql/result_set.py | 32 ++++- tests/unit/test_type_conversion.py | 161 +++++++++++++++++++++++ 3 files changed, 392 insertions(+), 2 deletions(-) create mode 100644 src/databricks/sql/conversion.py create mode 100644 tests/unit/test_type_conversion.py diff --git a/src/databricks/sql/conversion.py b/src/databricks/sql/conversion.py new file mode 100644 index 000000000..f6f98242f --- /dev/null +++ b/src/databricks/sql/conversion.py @@ -0,0 +1,201 @@ +""" +Type conversion utilities for the Databricks SQL Connector. + +This module provides functionality to convert string values from SEA Inline results +to appropriate Python types based on column metadata. +""" + +import datetime +import decimal +import logging +from dateutil import parser +from typing import Any, Callable, Dict, Optional, Union + +logger = logging.getLogger(__name__) + + +class SqlType: + """SQL type constants for improved maintainability.""" + + # Numeric types + TINYINT = "tinyint" + SMALLINT = "smallint" + INT = "int" + INTEGER = "integer" + BIGINT = "bigint" + FLOAT = "float" + REAL = "real" + DOUBLE = "double" + DECIMAL = "decimal" + NUMERIC = "numeric" + + # Boolean types + BOOLEAN = "boolean" + BIT = "bit" + + # Date/Time types + DATE = "date" + TIME = "time" + TIMESTAMP = "timestamp" + TIMESTAMP_NTZ = "timestamp_ntz" + TIMESTAMP_LTZ = "timestamp_ltz" + TIMESTAMP_TZ = "timestamp_tz" + + # String types + CHAR = "char" + VARCHAR = "varchar" + STRING = "string" + TEXT = "text" + + # Binary types + BINARY = "binary" + VARBINARY = "varbinary" + + # Complex types + ARRAY = "array" + MAP = "map" + STRUCT = "struct" + + @classmethod + def is_numeric(cls, sql_type: str) -> bool: + """Check if the SQL type is a numeric type.""" + return sql_type.lower() in ( + cls.TINYINT, + cls.SMALLINT, + cls.INT, + cls.INTEGER, + cls.BIGINT, + cls.FLOAT, + cls.REAL, + cls.DOUBLE, + cls.DECIMAL, + cls.NUMERIC, + ) + + @classmethod + def is_boolean(cls, sql_type: str) -> bool: + """Check if the SQL type is a boolean type.""" + return sql_type.lower() in (cls.BOOLEAN, cls.BIT) + + @classmethod + def is_datetime(cls, sql_type: str) -> bool: + """Check if the SQL type is a date/time type.""" + return sql_type.lower() in ( + cls.DATE, + cls.TIME, + cls.TIMESTAMP, + cls.TIMESTAMP_NTZ, + cls.TIMESTAMP_LTZ, + cls.TIMESTAMP_TZ, + ) + + @classmethod + def is_string(cls, sql_type: str) -> bool: + """Check if the SQL type is a string type.""" + return sql_type.lower() in (cls.CHAR, cls.VARCHAR, cls.STRING, cls.TEXT) + + @classmethod + def is_binary(cls, sql_type: str) -> bool: + """Check if the SQL type is a binary type.""" + return sql_type.lower() in (cls.BINARY, cls.VARBINARY) + + @classmethod + def is_complex(cls, sql_type: str) -> bool: + """Check if the SQL type is a complex type.""" + sql_type = sql_type.lower() + return ( + sql_type.startswith(cls.ARRAY) + or sql_type.startswith(cls.MAP) + or sql_type.startswith(cls.STRUCT) + ) + + +class SqlTypeConverter: + """ + Utility class for converting SQL types to Python types. + Based on the JDBC ConverterHelper implementation. 
+ """ + + # SQL type to conversion function mapping + TYPE_MAPPING: Dict[str, Callable] = { + # Numeric types + SqlType.TINYINT: lambda v: int(v), + SqlType.SMALLINT: lambda v: int(v), + SqlType.INT: lambda v: int(v), + SqlType.INTEGER: lambda v: int(v), + SqlType.BIGINT: lambda v: int(v), + SqlType.FLOAT: lambda v: float(v), + SqlType.REAL: lambda v: float(v), + SqlType.DOUBLE: lambda v: float(v), + SqlType.DECIMAL: lambda v, p=None, s=None: ( + decimal.Decimal(v).quantize( + decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p) + ) + if p is not None and s is not None + else decimal.Decimal(v) + ), + SqlType.NUMERIC: lambda v, p=None, s=None: ( + decimal.Decimal(v).quantize( + decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p) + ) + if p is not None and s is not None + else decimal.Decimal(v) + ), + # Boolean types + SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), + SqlType.BIT: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), + # Date/Time types + SqlType.DATE: lambda v: datetime.date.fromisoformat(v), + SqlType.TIME: lambda v: datetime.time.fromisoformat(v), + SqlType.TIMESTAMP: lambda v: parser.parse(v), + SqlType.TIMESTAMP_NTZ: lambda v: parser.parse(v).replace(tzinfo=None), + SqlType.TIMESTAMP_LTZ: lambda v: parser.parse(v).astimezone(tz=None), + SqlType.TIMESTAMP_TZ: lambda v: parser.parse(v), + # String types - no conversion needed + SqlType.CHAR: lambda v: v, + SqlType.VARCHAR: lambda v: v, + SqlType.STRING: lambda v: v, + SqlType.TEXT: lambda v: v, + # Binary types + SqlType.BINARY: lambda v: bytes.fromhex(v), + SqlType.VARBINARY: lambda v: bytes.fromhex(v), + } + + @staticmethod + def convert_value( + value: Any, + sql_type: str, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> Any: + """ + Convert a string value to the appropriate Python type based on SQL type. + + Args: + value: The string value to convert + sql_type: The SQL type (e.g., 'int', 'decimal') + precision: Optional precision for decimal types + scale: Optional scale for decimal types + + Returns: + The converted value in the appropriate Python type + """ + if value is None: + return None + + # Normalize SQL type + sql_type = sql_type.lower().strip() + + # Handle primitive types using the mapping + if sql_type not in SqlTypeConverter.TYPE_MAPPING: + return value + + converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] + try: + if sql_type in (SqlType.DECIMAL, SqlType.NUMERIC): + return converter_func(value, precision, scale) + else: + return converter_func(value) + except (ValueError, TypeError, decimal.InvalidOperation) as e: + logger.warning(f"Error converting value '{value}' to {sql_type}: {e}") + return value diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index c67e9b3f2..956742cd0 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -6,6 +6,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.conversion import SqlTypeConverter try: import pyarrow @@ -503,17 +504,44 @@ def __init__( def _convert_json_table(self, rows): """ Convert raw data rows to Row objects with named columns based on description. + Also converts string values to appropriate Python types based on column metadata. 
+ Args: rows: List of raw data rows Returns: - List of Row objects with named columns + List of Row objects with named columns and converted values """ if not self.description or not rows: return rows column_names = [col[0] for col in self.description] ResultRow = Row(*column_names) - return [ResultRow(*row) for row in rows] + + # JSON + INLINE gives us string values, so we convert them to appropriate + # types based on column metadata + converted_rows = [] + for row in rows: + converted_row = [] + + for i, value in enumerate(row): + column_type = self.description[i][1] + precision = self.description[i][4] + scale = self.description[i][5] + + try: + converted_value = SqlTypeConverter.convert_value( + value, column_type, precision=precision, scale=scale + ) + converted_row.append(converted_value) + except Exception as e: + logger.warning( + f"Error converting value '{value}' to {column_type}: {e}" + ) + converted_row.append(value) + + converted_rows.append(ResultRow(*converted_row)) + + return converted_rows def fetchmany_json(self, size: int): """ diff --git a/tests/unit/test_type_conversion.py b/tests/unit/test_type_conversion.py new file mode 100644 index 000000000..9b2735657 --- /dev/null +++ b/tests/unit/test_type_conversion.py @@ -0,0 +1,161 @@ +""" +Unit tests for the type conversion utilities. +""" + +import unittest +from datetime import date, datetime, time +from decimal import Decimal + +from databricks.sql.conversion import SqlType, SqlTypeConverter + + +class TestSqlType(unittest.TestCase): + """Tests for the SqlType class.""" + + def test_is_numeric(self): + """Test the is_numeric method.""" + self.assertTrue(SqlType.is_numeric(SqlType.INT)) + self.assertTrue(SqlType.is_numeric(SqlType.TINYINT)) + self.assertTrue(SqlType.is_numeric(SqlType.SMALLINT)) + self.assertTrue(SqlType.is_numeric(SqlType.BIGINT)) + self.assertTrue(SqlType.is_numeric(SqlType.FLOAT)) + self.assertTrue(SqlType.is_numeric(SqlType.DOUBLE)) + self.assertTrue(SqlType.is_numeric(SqlType.DECIMAL)) + self.assertTrue(SqlType.is_numeric(SqlType.NUMERIC)) + self.assertFalse(SqlType.is_numeric(SqlType.BOOLEAN)) + self.assertFalse(SqlType.is_numeric(SqlType.STRING)) + self.assertFalse(SqlType.is_numeric(SqlType.DATE)) + + def test_is_boolean(self): + """Test the is_boolean method.""" + self.assertTrue(SqlType.is_boolean(SqlType.BOOLEAN)) + self.assertTrue(SqlType.is_boolean(SqlType.BIT)) + self.assertFalse(SqlType.is_boolean(SqlType.INT)) + self.assertFalse(SqlType.is_boolean(SqlType.STRING)) + + def test_is_datetime(self): + """Test the is_datetime method.""" + self.assertTrue(SqlType.is_datetime(SqlType.DATE)) + self.assertTrue(SqlType.is_datetime(SqlType.TIME)) + self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP)) + self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_NTZ)) + self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_LTZ)) + self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_TZ)) + self.assertFalse(SqlType.is_datetime(SqlType.INT)) + self.assertFalse(SqlType.is_datetime(SqlType.STRING)) + + def test_is_string(self): + """Test the is_string method.""" + self.assertTrue(SqlType.is_string(SqlType.CHAR)) + self.assertTrue(SqlType.is_string(SqlType.VARCHAR)) + self.assertTrue(SqlType.is_string(SqlType.STRING)) + self.assertTrue(SqlType.is_string(SqlType.TEXT)) + self.assertFalse(SqlType.is_string(SqlType.INT)) + self.assertFalse(SqlType.is_string(SqlType.DATE)) + + def test_is_binary(self): + """Test the is_binary method.""" + self.assertTrue(SqlType.is_binary(SqlType.BINARY)) + 
self.assertTrue(SqlType.is_binary(SqlType.VARBINARY)) + self.assertFalse(SqlType.is_binary(SqlType.INT)) + self.assertFalse(SqlType.is_binary(SqlType.STRING)) + + def test_is_complex(self): + """Test the is_complex method.""" + self.assertTrue(SqlType.is_complex("array")) + self.assertTrue(SqlType.is_complex("map")) + self.assertTrue(SqlType.is_complex("struct")) + self.assertFalse(SqlType.is_complex(SqlType.INT)) + self.assertFalse(SqlType.is_complex(SqlType.STRING)) + + +class TestSqlTypeConverter(unittest.TestCase): + """Tests for the SqlTypeConverter class.""" + + def test_numeric_conversions(self): + """Test numeric type conversions.""" + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.INT), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.TINYINT), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SMALLINT), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BIGINT), 123) + self.assertEqual( + SqlTypeConverter.convert_value("123.45", SqlType.FLOAT), 123.45 + ) + self.assertEqual( + SqlTypeConverter.convert_value("123.45", SqlType.DOUBLE), 123.45 + ) + self.assertEqual( + SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL), Decimal("123.45") + ) + + # Test decimal with precision and scale + self.assertEqual( + SqlTypeConverter.convert_value( + "123.456", SqlType.DECIMAL, precision=5, scale=2 + ), + Decimal("123.46"), # Rounded to scale 2 + ) + + def test_boolean_conversions(self): + """Test boolean type conversions.""" + self.assertTrue(SqlTypeConverter.convert_value("true", SqlType.BOOLEAN)) + self.assertTrue(SqlTypeConverter.convert_value("TRUE", SqlType.BOOLEAN)) + self.assertTrue(SqlTypeConverter.convert_value("1", SqlType.BOOLEAN)) + self.assertTrue(SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN)) + self.assertFalse(SqlTypeConverter.convert_value("false", SqlType.BOOLEAN)) + self.assertFalse(SqlTypeConverter.convert_value("FALSE", SqlType.BOOLEAN)) + self.assertFalse(SqlTypeConverter.convert_value("0", SqlType.BOOLEAN)) + self.assertFalse(SqlTypeConverter.convert_value("no", SqlType.BOOLEAN)) + + def test_datetime_conversions(self): + """Test date/time type conversions.""" + self.assertEqual( + SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE), + date(2023, 1, 15), + ) + self.assertEqual( + SqlTypeConverter.convert_value("14:30:45", SqlType.TIME), time(14, 30, 45) + ) + self.assertEqual( + SqlTypeConverter.convert_value("2023-01-15 14:30:45", SqlType.TIMESTAMP), + datetime(2023, 1, 15, 14, 30, 45), + ) + + def test_string_conversions(self): + """Test string type conversions.""" + self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.STRING), "test") + self.assertEqual( + SqlTypeConverter.convert_value("test", SqlType.VARCHAR), "test" + ) + self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test") + self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.TEXT), "test") + + def test_error_handling(self): + """Test error handling in conversions.""" + # Test invalid conversions - should return original value + self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.INT), "abc") + self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.FLOAT), "abc") + self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.DECIMAL), "abc") + + def test_null_handling(self): + """Test handling of NULL values.""" + self.assertIsNone(SqlTypeConverter.convert_value(None, SqlType.INT)) + self.assertIsNone(SqlTypeConverter.convert_value(None, 
SqlType.STRING)) + self.assertIsNone(SqlTypeConverter.convert_value(None, SqlType.DATE)) + + def test_complex_type_handling(self): + """Test handling of complex types.""" + # Complex types should be returned as-is for now + self.assertEqual( + SqlTypeConverter.convert_value('{"a": 1}', "array"), '{"a": 1}' + ) + self.assertEqual( + SqlTypeConverter.convert_value('{"a": 1}', "map"), '{"a": 1}' + ) + self.assertEqual( + SqlTypeConverter.convert_value('{"a": 1}', "struct"), '{"a": 1}' + ) + + +if __name__ == "__main__": + unittest.main() From 9f0f969360efe2fa0078e10124aa3712adb8bf21 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 03:33:07 +0000 Subject: [PATCH 156/204] remove SEA running on metadata queries (known failures Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 44 ---------------------------------------- 1 file changed, 44 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 44d507ff9..d31ba9b8e 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -352,17 +352,6 @@ def test_create_table_will_return_empty_result_set(self, extra_params): finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) def test_get_tables(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -409,17 +398,6 @@ def test_get_tables(self, extra_params): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) def test_get_columns(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -543,17 +521,6 @@ def test_escape_single_quotes(self, extra_params): rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) def test_get_schemas(self, extra_params): with self.cursor(extra_params) as cursor: database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -572,17 +539,6 @@ def test_get_schemas(self, extra_params): finally: cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) def test_get_catalogs(self, extra_params): with self.cursor(extra_params) as cursor: cursor.catalogs() From 04a1936627a9d1a255ed0b1527f94f31e5981639 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 03:36:56 +0000 Subject: [PATCH 157/204] remove un-necessary docstrings Signed-off-by: varun-edachali-dbx --- src/databricks/sql/conversion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/databricks/sql/conversion.py b/src/databricks/sql/conversion.py index f6f98242f..602378f41 100644 --- a/src/databricks/sql/conversion.py +++ b/src/databricks/sql/conversion.py @@ -183,10 +183,8 @@ def convert_value( if value is None: return None - # Normalize SQL type sql_type = sql_type.lower().strip() - # Handle primitive types using the mapping if 
sql_type not in SqlTypeConverter.TYPE_MAPPING: return value From 278b8cd5d076d5a9d8e705e754e48a1c93e3bb44 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 05:14:29 +0000 Subject: [PATCH 158/204] align expected types with databricks sdk Signed-off-by: varun-edachali-dbx --- .../sql/{ => backend/sea}/conversion.py | 93 +++++++------------ src/databricks/sql/result_set.py | 2 +- ..._type_conversion.py => test_conversion.py} | 43 ++++----- 3 files changed, 56 insertions(+), 82 deletions(-) rename src/databricks/sql/{ => backend/sea}/conversion.py (65%) rename tests/unit/{test_type_conversion.py => test_conversion.py} (83%) diff --git a/src/databricks/sql/conversion.py b/src/databricks/sql/backend/sea/conversion.py similarity index 65% rename from src/databricks/sql/conversion.py rename to src/databricks/sql/backend/sea/conversion.py index 602378f41..fe263dce3 100644 --- a/src/databricks/sql/conversion.py +++ b/src/databricks/sql/backend/sea/conversion.py @@ -15,89 +15,75 @@ class SqlType: - """SQL type constants for improved maintainability.""" + """ + SQL type constants + """ # Numeric types - TINYINT = "tinyint" - SMALLINT = "smallint" + BYTE = "byte" + SHORT = "short" INT = "int" - INTEGER = "integer" - BIGINT = "bigint" + LONG = "long" FLOAT = "float" - REAL = "real" DOUBLE = "double" DECIMAL = "decimal" - NUMERIC = "numeric" - # Boolean types + # Boolean type BOOLEAN = "boolean" - BIT = "bit" # Date/Time types DATE = "date" - TIME = "time" TIMESTAMP = "timestamp" - TIMESTAMP_NTZ = "timestamp_ntz" - TIMESTAMP_LTZ = "timestamp_ltz" - TIMESTAMP_TZ = "timestamp_tz" + INTERVAL = "interval" # String types CHAR = "char" - VARCHAR = "varchar" STRING = "string" - TEXT = "text" - # Binary types + # Binary type BINARY = "binary" - VARBINARY = "varbinary" # Complex types ARRAY = "array" MAP = "map" STRUCT = "struct" + # Other types + NULL = "null" + USER_DEFINED_TYPE = "user_defined_type" + @classmethod def is_numeric(cls, sql_type: str) -> bool: """Check if the SQL type is a numeric type.""" return sql_type.lower() in ( - cls.TINYINT, - cls.SMALLINT, + cls.BYTE, + cls.SHORT, cls.INT, - cls.INTEGER, - cls.BIGINT, + cls.LONG, cls.FLOAT, - cls.REAL, cls.DOUBLE, cls.DECIMAL, - cls.NUMERIC, ) @classmethod def is_boolean(cls, sql_type: str) -> bool: """Check if the SQL type is a boolean type.""" - return sql_type.lower() in (cls.BOOLEAN, cls.BIT) + return sql_type.lower() == cls.BOOLEAN @classmethod def is_datetime(cls, sql_type: str) -> bool: """Check if the SQL type is a date/time type.""" - return sql_type.lower() in ( - cls.DATE, - cls.TIME, - cls.TIMESTAMP, - cls.TIMESTAMP_NTZ, - cls.TIMESTAMP_LTZ, - cls.TIMESTAMP_TZ, - ) + return sql_type.lower() in (cls.DATE, cls.TIMESTAMP, cls.INTERVAL) @classmethod def is_string(cls, sql_type: str) -> bool: """Check if the SQL type is a string type.""" - return sql_type.lower() in (cls.CHAR, cls.VARCHAR, cls.STRING, cls.TEXT) + return sql_type.lower() in (cls.CHAR, cls.STRING) @classmethod def is_binary(cls, sql_type: str) -> bool: """Check if the SQL type is a binary type.""" - return sql_type.lower() in (cls.BINARY, cls.VARBINARY) + return sql_type.lower() == cls.BINARY @classmethod def is_complex(cls, sql_type: str) -> bool: @@ -107,25 +93,25 @@ def is_complex(cls, sql_type: str) -> bool: sql_type.startswith(cls.ARRAY) or sql_type.startswith(cls.MAP) or sql_type.startswith(cls.STRUCT) + or sql_type == cls.USER_DEFINED_TYPE ) class SqlTypeConverter: """ Utility class for converting SQL types to Python types. 
- Based on the JDBC ConverterHelper implementation. + Based on the types supported by the Databricks SDK. """ # SQL type to conversion function mapping + # TODO: complex types TYPE_MAPPING: Dict[str, Callable] = { # Numeric types - SqlType.TINYINT: lambda v: int(v), - SqlType.SMALLINT: lambda v: int(v), + SqlType.BYTE: lambda v: int(v), + SqlType.SHORT: lambda v: int(v), SqlType.INT: lambda v: int(v), - SqlType.INTEGER: lambda v: int(v), - SqlType.BIGINT: lambda v: int(v), + SqlType.LONG: lambda v: int(v), SqlType.FLOAT: lambda v: float(v), - SqlType.REAL: lambda v: float(v), SqlType.DOUBLE: lambda v: float(v), SqlType.DECIMAL: lambda v, p=None, s=None: ( decimal.Decimal(v).quantize( @@ -134,31 +120,21 @@ class SqlTypeConverter: if p is not None and s is not None else decimal.Decimal(v) ), - SqlType.NUMERIC: lambda v, p=None, s=None: ( - decimal.Decimal(v).quantize( - decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p) - ) - if p is not None and s is not None - else decimal.Decimal(v) - ), - # Boolean types + # Boolean type SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), - SqlType.BIT: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), # Date/Time types SqlType.DATE: lambda v: datetime.date.fromisoformat(v), - SqlType.TIME: lambda v: datetime.time.fromisoformat(v), SqlType.TIMESTAMP: lambda v: parser.parse(v), - SqlType.TIMESTAMP_NTZ: lambda v: parser.parse(v).replace(tzinfo=None), - SqlType.TIMESTAMP_LTZ: lambda v: parser.parse(v).astimezone(tz=None), - SqlType.TIMESTAMP_TZ: lambda v: parser.parse(v), + SqlType.INTERVAL: lambda v: v, # Keep as string for now # String types - no conversion needed SqlType.CHAR: lambda v: v, - SqlType.VARCHAR: lambda v: v, SqlType.STRING: lambda v: v, - SqlType.TEXT: lambda v: v, - # Binary types + # Binary type SqlType.BINARY: lambda v: bytes.fromhex(v), - SqlType.VARBINARY: lambda v: bytes.fromhex(v), + # Other types + SqlType.NULL: lambda v: None, + # Complex types and user-defined types return as-is + SqlType.USER_DEFINED_TYPE: lambda v: v, } @staticmethod @@ -180,6 +156,7 @@ def convert_value( Returns: The converted value in the appropriate Python type """ + if value is None: return None @@ -190,7 +167,7 @@ def convert_value( converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] try: - if sql_type in (SqlType.DECIMAL, SqlType.NUMERIC): + if sql_type == SqlType.DECIMAL: return converter_func(value, precision, scale) else: return converter_func(value) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 956742cd0..d734db5c6 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -6,7 +6,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest -from databricks.sql.conversion import SqlTypeConverter +from databricks.sql.backend.sea.conversion import SqlTypeConverter try: import pyarrow diff --git a/tests/unit/test_type_conversion.py b/tests/unit/test_conversion.py similarity index 83% rename from tests/unit/test_type_conversion.py rename to tests/unit/test_conversion.py index 9b2735657..656e6730a 100644 --- a/tests/unit/test_type_conversion.py +++ b/tests/unit/test_conversion.py @@ -6,7 +6,7 @@ from datetime import date, datetime, time from decimal import Decimal -from databricks.sql.conversion import SqlType, SqlTypeConverter +from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter class TestSqlType(unittest.TestCase): @@ -15,13 +15,12 
@@ class TestSqlType(unittest.TestCase): def test_is_numeric(self): """Test the is_numeric method.""" self.assertTrue(SqlType.is_numeric(SqlType.INT)) - self.assertTrue(SqlType.is_numeric(SqlType.TINYINT)) - self.assertTrue(SqlType.is_numeric(SqlType.SMALLINT)) - self.assertTrue(SqlType.is_numeric(SqlType.BIGINT)) + self.assertTrue(SqlType.is_numeric(SqlType.BYTE)) + self.assertTrue(SqlType.is_numeric(SqlType.SHORT)) + self.assertTrue(SqlType.is_numeric(SqlType.LONG)) self.assertTrue(SqlType.is_numeric(SqlType.FLOAT)) self.assertTrue(SqlType.is_numeric(SqlType.DOUBLE)) self.assertTrue(SqlType.is_numeric(SqlType.DECIMAL)) - self.assertTrue(SqlType.is_numeric(SqlType.NUMERIC)) self.assertFalse(SqlType.is_numeric(SqlType.BOOLEAN)) self.assertFalse(SqlType.is_numeric(SqlType.STRING)) self.assertFalse(SqlType.is_numeric(SqlType.DATE)) @@ -29,34 +28,27 @@ def test_is_numeric(self): def test_is_boolean(self): """Test the is_boolean method.""" self.assertTrue(SqlType.is_boolean(SqlType.BOOLEAN)) - self.assertTrue(SqlType.is_boolean(SqlType.BIT)) self.assertFalse(SqlType.is_boolean(SqlType.INT)) self.assertFalse(SqlType.is_boolean(SqlType.STRING)) def test_is_datetime(self): """Test the is_datetime method.""" self.assertTrue(SqlType.is_datetime(SqlType.DATE)) - self.assertTrue(SqlType.is_datetime(SqlType.TIME)) self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP)) - self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_NTZ)) - self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_LTZ)) - self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP_TZ)) + self.assertTrue(SqlType.is_datetime(SqlType.INTERVAL)) self.assertFalse(SqlType.is_datetime(SqlType.INT)) self.assertFalse(SqlType.is_datetime(SqlType.STRING)) def test_is_string(self): """Test the is_string method.""" self.assertTrue(SqlType.is_string(SqlType.CHAR)) - self.assertTrue(SqlType.is_string(SqlType.VARCHAR)) self.assertTrue(SqlType.is_string(SqlType.STRING)) - self.assertTrue(SqlType.is_string(SqlType.TEXT)) self.assertFalse(SqlType.is_string(SqlType.INT)) self.assertFalse(SqlType.is_string(SqlType.DATE)) def test_is_binary(self): """Test the is_binary method.""" self.assertTrue(SqlType.is_binary(SqlType.BINARY)) - self.assertTrue(SqlType.is_binary(SqlType.VARBINARY)) self.assertFalse(SqlType.is_binary(SqlType.INT)) self.assertFalse(SqlType.is_binary(SqlType.STRING)) @@ -75,9 +67,9 @@ class TestSqlTypeConverter(unittest.TestCase): def test_numeric_conversions(self): """Test numeric type conversions.""" self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.INT), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.TINYINT), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SMALLINT), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BIGINT), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BYTE), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SHORT), 123) + self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.LONG), 123) self.assertEqual( SqlTypeConverter.convert_value("123.45", SqlType.FLOAT), 123.45 ) @@ -113,9 +105,6 @@ def test_datetime_conversions(self): SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE), date(2023, 1, 15), ) - self.assertEqual( - SqlTypeConverter.convert_value("14:30:45", SqlType.TIME), time(14, 30, 45) - ) self.assertEqual( SqlTypeConverter.convert_value("2023-01-15 14:30:45", SqlType.TIMESTAMP), datetime(2023, 1, 15, 14, 30, 45), @@ -124,15 +113,19 @@ def test_datetime_conversions(self): def 
test_string_conversions(self): """Test string type conversions.""" self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.STRING), "test") + self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test") + + def test_binary_conversions(self): + """Test binary type conversions.""" + hex_str = "68656c6c6f" # "hello" in hex + expected_bytes = b"hello" + self.assertEqual( - SqlTypeConverter.convert_value("test", SqlType.VARCHAR), "test" + SqlTypeConverter.convert_value(hex_str, SqlType.BINARY), expected_bytes ) - self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test") - self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.TEXT), "test") def test_error_handling(self): """Test error handling in conversions.""" - # Test invalid conversions - should return original value self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.INT), "abc") self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.FLOAT), "abc") self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.DECIMAL), "abc") @@ -155,6 +148,10 @@ def test_complex_type_handling(self): self.assertEqual( SqlTypeConverter.convert_value('{"a": 1}', "struct"), '{"a": 1}' ) + self.assertEqual( + SqlTypeConverter.convert_value('{"a": 1}', SqlType.USER_DEFINED_TYPE), + '{"a": 1}', + ) if __name__ == "__main__": From 91b7f7f9fa374b1fff2275e16bb5c370d01e22e8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 05:36:54 +0000 Subject: [PATCH 159/204] link rest api reference to validate types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/conversion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/databricks/sql/backend/sea/conversion.py b/src/databricks/sql/backend/sea/conversion.py index fe263dce3..019305bf0 100644 --- a/src/databricks/sql/backend/sea/conversion.py +++ b/src/databricks/sql/backend/sea/conversion.py @@ -17,6 +17,9 @@ class SqlType: """ SQL type constants + + The list of types can be found in the SEA REST API Reference: + https://docs.databricks.com/api/workspace/statementexecution/executestatement """ # Numeric types From 7a5ae1366218572be9b8a495c2e8f4948d844153 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 05:43:46 +0000 Subject: [PATCH 160/204] remove test_catalogs_returns_arrow_table test metadata commands not expected to pass Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index d31ba9b8e..18a7be965 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -935,19 +935,8 @@ def test_decimal_not_returned_as_strings_arrow(self): assert pyarrow.types.is_decimal(decimal_type) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_catalogs_returns_arrow_table(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_catalogs_returns_arrow_table(self): + with self.cursor() as cursor: cursor.catalogs() results = cursor.fetchall_arrow() assert isinstance(results, pyarrow.Table) From f1776f3e333649784779e21a828ce46636dc7172 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:17:01 +0000 Subject: [PATCH 161/204] fix fetchall_arrow and fetchmany_arrow Signed-off-by: varun-edachali-dbx --- 
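Note (annotation, not part of the commit): a minimal standalone sketch of the column-pivot approach used in the diff below, assuming pyarrow is installed and that description follows the DB-API 7-tuple layout; the sample rows and column names are illustrative only.

    import pyarrow

    # rows as JsonQueue returns them for JSON_ARRAY / INLINE results: one list per row
    rows = [["a", 1], ["b", 2]]
    # DB-API style description; only the column name (index 0) is needed for field names
    description = [
        ("col1", "string", None, None, None, None, None),
        ("col2", "int", None, None, None, None, None),
    ]

    # pivot row-major data into one Python list per column, then build an Arrow table
    columns = [[row[i] for row in rows] for i in range(len(rows[0]))]
    names = [col[0] for col in description]
    table = pyarrow.Table.from_arrays(columns, names=names)
    assert table.num_rows == 2

This mirrors what the new _convert_json_to_arrow helper does before fetchmany_arrow and fetchall_arrow advance the row index.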
src/databricks/sql/result_set.py | 53 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index d734db5c6..b8bdd3935 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -501,22 +501,25 @@ def __init__( # Initialize queue for result data if not provided self.results = results_queue or JsonQueue([]) - def _convert_json_table(self, rows): + def _convert_json_to_arrow(self, rows): + """ + Convert raw data rows to Arrow table. + """ + columns = [] + num_cols = len(rows[0]) + for i in range(num_cols): + columns.append([row[i] for row in rows]) + names = [col[0] for col in self.description] + return pyarrow.Table.from_arrays(columns, names=names) + + def _convert_json_types(self, rows): """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. - - Args: - rows: List of raw data rows - Returns: - List of Row objects with named columns and converted values """ if not self.description or not rows: return rows - column_names = [col[0] for col in self.description] - ResultRow = Row(*column_names) - # JSON + INLINE gives us string values, so we convert them to appropriate # types based on column metadata converted_rows = [] @@ -539,10 +542,28 @@ def _convert_json_table(self, rows): ) converted_row.append(value) - converted_rows.append(ResultRow(*converted_row)) + converted_rows.append(converted_row) return converted_rows + def _convert_json_table(self, rows): + """ + Convert raw data rows to Row objects with named columns based on description. + Also converts string values to appropriate Python types based on column metadata. + + Args: + rows: List of raw data rows + Returns: + List of Row objects with named columns and converted values + """ + if not self.description or not rows: + return rows + + ResultRow = Row(*[col[0] for col in self.description]) + rows = self._convert_json_types(rows) + + return [ResultRow(*row) for row in rows] + def fetchmany_json(self, size: int): """ Fetch the next set of rows as a columnar table. @@ -593,7 +614,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - results = self.results.next_n_rows(size) + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchmany_arrow only supported for JSON data") + + rows = self._convert_json_types(self.results.next_n_rows(size)) + results = self._convert_json_to_arrow(rows) self._next_row_index += results.num_rows return results @@ -602,7 +627,11 @@ def fetchall_arrow(self) -> "pyarrow.Table": """ Fetch all remaining rows as an Arrow table. 
""" - results = self.results.remaining_rows() + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchall_arrow only supported for JSON data") + + rows = self._convert_json_types(self.results.remaining_rows()) + results = self._convert_json_to_arrow(rows) self._next_row_index += results.num_rows return results From 61433312bdc8254b814b41f512dd4f8c49890aa3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:19:03 +0000 Subject: [PATCH 162/204] remove thrift aligned test_cancel_during_execute from SEA tests Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 18a7be965..9d0d0141e 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -590,19 +590,8 @@ def test_unicode(self, extra_params): assert len(results) == 1 and len(results[0]) == 1 assert results[0][0] == unicode_str - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_cancel_during_execute(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_cancel_during_execute(self): + with self.cursor() as cursor: def execute_really_long_query(): cursor.execute( From 5eaded4ccc358bf8be7551e2d46876eaff363c5c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:22:12 +0000 Subject: [PATCH 163/204] remove un-necessary changes in example scripts Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_async_query.py | 4 ++-- examples/experimental/tests/test_sea_sync_query.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index f805834b4..a2c27323f 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -14,7 +14,7 @@ def test_sea_async_query_with_cloud_fetch(): """ - Test executing a query asynchronously using the SEA backend with cloud fetch enabled. + Test executing a simple query asynchronously using the SEA backend with cloud fetch enabled. This function connects to a Databricks SQL endpoint using the SEA backend, executes a query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. @@ -120,7 +120,7 @@ def test_sea_async_query_with_cloud_fetch(): def test_sea_async_query_without_cloud_fetch(): """ - Test executing a query asynchronously using the SEA backend with cloud fetch disabled. + Test executing a simple query asynchronously using the SEA backend with cloud fetch disabled. This function connects to a Databricks SQL endpoint using the SEA backend, executes a query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 9566da5cd..ba9272adf 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -12,7 +12,7 @@ def test_sea_sync_query_with_cloud_fetch(): """ - Test executing a query synchronously using the SEA backend with cloud fetch enabled. + Test executing a simple query synchronously using the SEA backend with cloud fetch enabled. 
This function connects to a Databricks SQL endpoint using the SEA backend, executes a query with cloud fetch enabled, and verifies that execution completes successfully. @@ -90,7 +90,7 @@ def test_sea_sync_query_with_cloud_fetch(): def test_sea_sync_query_without_cloud_fetch(): """ - Test executing a query synchronously using the SEA backend with cloud fetch disabled. + Test executing a simple query synchronously using the SEA backend with cloud fetch disabled. This function connects to a Databricks SQL endpoint using the SEA backend, executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. From eeed9a156871532a1f194ca87e5c9f9597c0eb92 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:23:31 +0000 Subject: [PATCH 164/204] remove un-necessary chagnes in example scripts Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_async_query.py | 8 ++++---- examples/experimental/tests/test_sea_sync_query.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index a2c27323f..1685ac4ca 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -14,10 +14,10 @@ def test_sea_async_query_with_cloud_fetch(): """ - Test executing a simple query asynchronously using the SEA backend with cloud fetch enabled. + Test executing a query asynchronously using the SEA backend with cloud fetch enabled. This function connects to a Databricks SQL endpoint using the SEA backend, - executes a query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. + executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully. """ server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") @@ -120,10 +120,10 @@ def test_sea_async_query_with_cloud_fetch(): def test_sea_async_query_without_cloud_fetch(): """ - Test executing a simple query asynchronously using the SEA backend with cloud fetch disabled. + Test executing a query asynchronously using the SEA backend with cloud fetch disabled. This function connects to a Databricks SQL endpoint using the SEA backend, - executes a query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. + executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully. """ server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index ba9272adf..76941e2d2 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -12,10 +12,10 @@ def test_sea_sync_query_with_cloud_fetch(): """ - Test executing a simple query synchronously using the SEA backend with cloud fetch enabled. + Test executing a query synchronously using the SEA backend with cloud fetch enabled. This function connects to a Databricks SQL endpoint using the SEA backend, - executes a query with cloud fetch enabled, and verifies that execution completes successfully. + executes a simple query with cloud fetch enabled, and verifies that execution completes successfully. 
""" server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") @@ -90,7 +90,7 @@ def test_sea_sync_query_with_cloud_fetch(): def test_sea_sync_query_without_cloud_fetch(): """ - Test executing a simple query synchronously using the SEA backend with cloud fetch disabled. + Test executing a query synchronously using the SEA backend with cloud fetch disabled. This function connects to a Databricks SQL endpoint using the SEA backend, executes a simple query with cloud fetch disabled, and verifies that execution completes successfully. From f23388631504d91c1cb0fe84e76d7419cb9d746b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:27:02 +0000 Subject: [PATCH 165/204] _convert_json_table -> _create_json_table Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index b8bdd3935..95c9c4823 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -471,9 +471,9 @@ def __init__( manifest: Manifest from SEA response (optional) """ - results_queue = None + self.results = None if result_data: - results_queue = SeaResultSetQueueFactory.build_queue( + self.results = SeaResultSetQueueFactory.build_queue( result_data, manifest, str(execute_response.command_id.to_sea_statement_id()), @@ -498,9 +498,6 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", ) - # Initialize queue for result data if not provided - self.results = results_queue or JsonQueue([]) - def _convert_json_to_arrow(self, rows): """ Convert raw data rows to Arrow table. @@ -546,7 +543,7 @@ def _convert_json_types(self, rows): return converted_rows - def _convert_json_table(self, rows): + def _create_json_table(self, rows): """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. 
@@ -645,7 +642,7 @@ def fetchone(self) -> Optional[Row]: A single Row object or None if no more rows are available """ if isinstance(self.results, JsonQueue): - res = self._convert_json_table(self.fetchmany_json(1)) + res = self._create_json_table(self.fetchmany_json(1)) else: raise NotImplementedError("fetchone only supported for JSON data") @@ -665,7 +662,7 @@ def fetchmany(self, size: int) -> List[Row]: ValueError: If size is negative """ if isinstance(self.results, JsonQueue): - return self._convert_json_table(self.fetchmany_json(size)) + return self._create_json_table(self.fetchmany_json(size)) else: raise NotImplementedError("fetchmany only supported for JSON data") @@ -677,6 +674,6 @@ def fetchall(self) -> List[Row]: List of Row objects containing all remaining rows """ if isinstance(self.results, JsonQueue): - return self._convert_json_table(self.fetchall_json()) + return self._create_json_table(self.fetchall_json()) else: raise NotImplementedError("fetchall only supported for JSON data") From 68ac4374b287caa7b87295d2d44dd01876adcb7a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:29:09 +0000 Subject: [PATCH 166/204] remove accidentally removed test Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index e16aa5008..b79aaa093 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -637,6 +637,31 @@ def test_utility_methods(self, sea_client): sea_client._extract_description_from_manifest(no_columns_manifest) is None ) + def test_results_message_to_execute_response_is_staging_operation(self, sea_client): + """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" + # Test when is_volume_operation is True + response = MagicMock() + response.statement_id = "test-statement-123" + response.status.state = CommandState.SUCCEEDED + response.manifest.is_volume_operation = True + response.manifest.result_compression = "NONE" + response.manifest.format = "JSON_ARRAY" + + # Mock the _extract_description_from_manifest method to return None + with patch.object( + sea_client, "_extract_description_from_manifest", return_value=None + ): + result = sea_client._results_message_to_execute_response(response) + assert result.is_staging_operation is True + + # Test when is_volume_operation is False + response.manifest.is_volume_operation = False + with patch.object( + sea_client, "_extract_description_from_manifest", return_value=None + ): + result = sea_client._results_message_to_execute_response(response) + assert result.is_staging_operation is False + def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): """Test the get_catalogs method.""" # Mock the execute_command method From 7fd0845afa480c31c6979038376cefc0f3d8bfe4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:32:03 +0000 Subject: [PATCH 167/204] remove new unit tests (to be re-added based on new arch) Signed-off-by: varun-edachali-dbx --- tests/unit/test_json_queue.py | 137 ------------------ .../unit/test_sea_result_set_queue_factory.py | 87 ----------- 2 files changed, 224 deletions(-) delete mode 100644 tests/unit/test_json_queue.py delete mode 100644 tests/unit/test_sea_result_set_queue_factory.py diff --git a/tests/unit/test_json_queue.py b/tests/unit/test_json_queue.py deleted file mode 100644 index ee19a574f..000000000 --- a/tests/unit/test_json_queue.py 
+++ /dev/null @@ -1,137 +0,0 @@ -""" -Tests for the JsonQueue class. - -This module contains tests for the JsonQueue class, which implements -a queue for JSON array data returned by the SEA backend. -""" - -import pytest -from databricks.sql.utils import JsonQueue - - -class TestJsonQueue: - """Test suite for the JsonQueue class.""" - - @pytest.fixture - def sample_data_array(self): - """Create a sample data array for testing.""" - return [ - [1, "value1"], - [2, "value2"], - [3, "value3"], - [4, "value4"], - [5, "value5"], - ] - - def test_init(self, sample_data_array): - """Test initializing JsonQueue with a data array.""" - queue = JsonQueue(sample_data_array) - assert queue.data_array == sample_data_array - assert queue.cur_row_index == 0 - assert queue.n_valid_rows == 5 - - def test_next_n_rows_partial(self, sample_data_array): - """Test getting a subset of rows.""" - queue = JsonQueue(sample_data_array) - rows = queue.next_n_rows(3) - - # Check that we got the first 3 rows - assert rows == sample_data_array[:3] - - # Check that the current row index was updated - assert queue.cur_row_index == 3 - - def test_next_n_rows_all(self, sample_data_array): - """Test getting all rows at once.""" - queue = JsonQueue(sample_data_array) - rows = queue.next_n_rows(10) # More than available - - # Check that we got all rows - assert rows == sample_data_array - - # Check that the current row index was updated - assert queue.cur_row_index == 5 - - def test_next_n_rows_empty(self): - """Test getting rows from an empty queue.""" - queue = JsonQueue([]) - rows = queue.next_n_rows(5) - - # Check that we got an empty list - assert rows == [] - - # Check that the current row index was not updated - assert queue.cur_row_index == 0 - - def test_next_n_rows_zero(self, sample_data_array): - """Test getting zero rows.""" - queue = JsonQueue(sample_data_array) - rows = queue.next_n_rows(0) - - # Check that we got an empty list - assert rows == [] - - # Check that the current row index was not updated - assert queue.cur_row_index == 0 - - def test_next_n_rows_sequential(self, sample_data_array): - """Test getting rows in multiple sequential calls.""" - queue = JsonQueue(sample_data_array) - - # Get first 2 rows - rows1 = queue.next_n_rows(2) - assert rows1 == sample_data_array[:2] - assert queue.cur_row_index == 2 - - # Get next 2 rows - rows2 = queue.next_n_rows(2) - assert rows2 == sample_data_array[2:4] - assert queue.cur_row_index == 4 - - # Get remaining rows - rows3 = queue.next_n_rows(2) - assert rows3 == sample_data_array[4:] - assert queue.cur_row_index == 5 - - def test_remaining_rows(self, sample_data_array): - """Test getting all remaining rows.""" - queue = JsonQueue(sample_data_array) - - # Get first 2 rows - queue.next_n_rows(2) - - # Get remaining rows - rows = queue.remaining_rows() - - # Check that we got the remaining rows - assert rows == sample_data_array[2:] - - # Check that the current row index was updated to the end - assert queue.cur_row_index == 5 - - def test_remaining_rows_empty(self): - """Test getting remaining rows from an empty queue.""" - queue = JsonQueue([]) - rows = queue.remaining_rows() - - # Check that we got an empty list - assert rows == [] - - # Check that the current row index was not updated - assert queue.cur_row_index == 0 - - def test_remaining_rows_after_all_consumed(self, sample_data_array): - """Test getting remaining rows after all rows have been consumed.""" - queue = JsonQueue(sample_data_array) - - # Consume all rows - queue.next_n_rows(10) - - # Try to get 
remaining rows - rows = queue.remaining_rows() - - # Check that we got an empty list - assert rows == [] - - # Check that the current row index was not updated - assert queue.cur_row_index == 5 diff --git a/tests/unit/test_sea_result_set_queue_factory.py b/tests/unit/test_sea_result_set_queue_factory.py deleted file mode 100644 index f72510afb..000000000 --- a/tests/unit/test_sea_result_set_queue_factory.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Tests for the SeaResultSetQueueFactory class. - -This module contains tests for the SeaResultSetQueueFactory class, which builds -appropriate result set queues for the SEA backend. -""" - -import pytest -from unittest.mock import Mock, patch - -from databricks.sql.utils import SeaResultSetQueueFactory, JsonQueue -from databricks.sql.backend.sea.models.base import ResultData, ResultManifest - - -class TestSeaResultSetQueueFactory: - """Test suite for the SeaResultSetQueueFactory class.""" - - @pytest.fixture - def mock_result_data_with_json(self): - """Create a mock ResultData with JSON data.""" - result_data = Mock(spec=ResultData) - result_data.data = [[1, "value1"], [2, "value2"]] - result_data.external_links = None - return result_data - - @pytest.fixture - def mock_result_data_with_external_links(self): - """Create a mock ResultData with external links.""" - result_data = Mock(spec=ResultData) - result_data.data = None - result_data.external_links = ["link1", "link2"] - return result_data - - @pytest.fixture - def mock_result_data_empty(self): - """Create a mock ResultData with no data.""" - result_data = Mock(spec=ResultData) - result_data.data = None - result_data.external_links = None - return result_data - - @pytest.fixture - def mock_manifest(self): - """Create a mock manifest.""" - return Mock(spec=ResultManifest) - - def test_build_queue_with_json_data( - self, mock_result_data_with_json, mock_manifest - ): - """Test building a queue with JSON data.""" - queue = SeaResultSetQueueFactory.build_queue( - sea_result_data=mock_result_data_with_json, - manifest=mock_manifest, - statement_id="test-statement-id", - ) - - # Check that we got a JsonQueue - assert isinstance(queue, JsonQueue) - - # Check that the queue has the correct data - assert queue.data_array == mock_result_data_with_json.data - - def test_build_queue_with_external_links( - self, mock_result_data_with_external_links, mock_manifest - ): - """Test building a queue with external links.""" - with pytest.raises( - NotImplementedError, - match="EXTERNAL_LINKS disposition is not implemented for SEA backend", - ): - SeaResultSetQueueFactory.build_queue( - sea_result_data=mock_result_data_with_external_links, - manifest=mock_manifest, - statement_id="test-statement-id", - ) - - def test_build_queue_with_empty_data(self, mock_result_data_empty, mock_manifest): - """Test building a queue with empty data.""" - queue = SeaResultSetQueueFactory.build_queue( - sea_result_data=mock_result_data_empty, - manifest=mock_manifest, - statement_id="test-statement-id", - ) - - # Check that we got a JsonQueue with empty data - assert isinstance(queue, JsonQueue) - assert queue.data_array == [] From ea7ff73e9b664827890d4233e9b1c60f6ceb5901 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 06:32:40 +0000 Subject: [PATCH 168/204] remove changes in sea_result_set functionality (to be re-added) Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_result_set.py | 363 ++++-------------------------- 1 file changed, 48 insertions(+), 315 deletions(-) diff --git 
a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 8c6b9ae3a..c596dbc14 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -10,8 +10,6 @@ from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.types import CommandId, CommandState, BackendType -from databricks.sql.utils import JsonQueue -from databricks.sql.types import Row class TestSeaResultSet: @@ -22,15 +20,12 @@ def mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True - connection.disable_pandas = False return connection @pytest.fixture def mock_sea_client(self): """Create a mock SEA client.""" - client = Mock() - client.max_download_threads = 10 - return client + return Mock() @pytest.fixture def execute_response(self): @@ -42,27 +37,11 @@ def execute_response(self): mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ - ("col1", "INT", None, None, None, None, None), - ("col2", "STRING", None, None, None, None, None), + ("test_value", "INT", None, None, None, None, None) ] mock_response.is_staging_operation = False - mock_response.lz4_compressed = False - mock_response.arrow_schema_bytes = b"" return mock_response - @pytest.fixture - def mock_result_data(self): - """Create mock result data.""" - result_data = Mock() - result_data.data = [[1, "value1"], [2, "value2"], [3, "value3"]] - result_data.external_links = None - return result_data - - @pytest.fixture - def mock_manifest(self): - """Create a mock manifest.""" - return Mock() - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -84,49 +63,6 @@ def test_init_with_execute_response( assert result_set.arraysize == 100 assert result_set.description == execute_response.description - # Verify that a JsonQueue was created with empty data - assert isinstance(result_set.results, JsonQueue) - assert result_set.results.data_array == [] - - def test_init_with_result_data( - self, - mock_connection, - mock_sea_client, - execute_response, - mock_result_data, - mock_manifest, - ): - """Test initializing SeaResultSet with result data.""" - with patch( - "databricks.sql.result_set.SeaResultSetQueueFactory" - ) as mock_factory: - mock_queue = Mock(spec=JsonQueue) - mock_factory.build_queue.return_value = mock_queue - - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - manifest=mock_manifest, - ) - - # Verify that the factory was called with the correct arguments - mock_factory.build_queue.assert_called_once_with( - mock_result_data, - mock_manifest, - str(execute_response.command_id.to_sea_statement_id()), - description=execute_response.description, - max_download_threads=mock_sea_client.max_download_threads, - sea_client=mock_sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) - - # Verify that the queue was set correctly - assert result_set.results == mock_queue - def test_close(self, mock_connection, mock_sea_client, execute_response): """Test closing a result set.""" result_set = SeaResultSet( @@ -187,10 +123,10 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_convert_json_table( + def test_unimplemented_methods( self, mock_connection, mock_sea_client, execute_response ): - """Test converting JSON 
data to Row objects.""" + """Test that unimplemented methods raise NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -199,142 +135,57 @@ def test_convert_json_table( arraysize=100, ) - # Sample data - data = [[1, "value1"], [2, "value2"]] - - # Convert to Row objects - rows = result_set._convert_json_table(data) - - # Check that we got Row objects with the correct values - assert len(rows) == 2 - assert isinstance(rows[0], Row) - assert rows[0].col1 == 1 - assert rows[0].col2 == "value1" - assert rows[1].col1 == 2 - assert rows[1].col2 == "value2" - - def test_convert_json_table_empty( - self, mock_connection, mock_sea_client, execute_response - ): - """Test converting empty JSON data.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Empty data - data = [] - - # Convert to Row objects - rows = result_set._convert_json_table(data) - - # Check that we got an empty list - assert rows == [] - - def test_convert_json_table_no_description( - self, mock_connection, mock_sea_client, execute_response - ): - """Test converting JSON data with no description.""" - execute_response.description = None - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - - # Sample data - data = [[1, "value1"], [2, "value2"]] - - # Convert to Row objects - rows = result_set._convert_json_table(data) - - # Check that we got the original data - assert rows == data - - def test_fetchone( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test fetching one row.""" - # Create a result set with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) - - # Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) - - # Fetch one row - row = result_set.fetchone() - - # Check that we got a Row object with the correct values - assert isinstance(row, Row) - assert row.col1 == 1 - assert row.col2 == "value1" - - # Check that the row index was updated - assert result_set._next_row_index == 1 - - def test_fetchone_empty(self, mock_connection, mock_sea_client, execute_response): - """Test fetching one row from an empty result set.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) + # Test each unimplemented method individually with specific error messages + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set.fetchone() - # Fetch one row - row = result_set.fetchone() + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) - # Check that we got None - assert row is None + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() - def test_fetchmany( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test fetching multiple rows.""" - # Create a result set 
with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() - # Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) - # Fetch two rows - rows = result_set.fetchmany(2) + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() - # Check that we got two Row objects with the correct values - assert len(rows) == 2 - assert isinstance(rows[0], Row) - assert rows[0].col1 == 1 - assert rows[0].col2 == "value1" - assert rows[1].col1 == 2 - assert rows[1].col2 == "value2" + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) - # Check that the row index was updated - assert result_set._next_row_index == 2 + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass - def test_fetchmany_negative_size( + def test_fill_results_buffer_not_implemented( self, mock_connection, mock_sea_client, execute_response ): - """Test fetching with a negative size.""" + """Test that _fill_results_buffer raises NotImplementedError.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -343,126 +194,8 @@ def test_fetchmany_negative_size( arraysize=100, ) - # Try to fetch with a negative size with pytest.raises( - ValueError, match="size argument for fetchmany is -1 but must be >= 0" + NotImplementedError, + match="_fill_results_buffer is not implemented for SEA backend", ): - result_set.fetchmany(-1) - - def test_fetchall( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test fetching all rows.""" - # Create a result set with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) - - # Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) - - # Fetch all rows - rows = result_set.fetchall() - - # Check that we got three Row objects with the correct values - assert len(rows) == 3 - assert isinstance(rows[0], Row) - assert rows[0].col1 == 1 - assert rows[0].col2 == "value1" - assert rows[1].col1 == 2 - assert rows[1].col2 == "value2" - assert rows[2].col1 == 3 - assert rows[2].col2 == "value3" - - # Check that the row index was updated - assert result_set._next_row_index == 3 - - def test_fetchmany_json( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test fetching JSON data directly.""" - # Create a result set with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) - - # 
Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) - - # Fetch two rows as JSON - rows = result_set.fetchmany_json(2) - - # Check that we got the raw data - assert rows == [[1, "value1"], [2, "value2"]] - - # Check that the row index was updated - assert result_set._next_row_index == 2 - - def test_fetchall_json( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test fetching all JSON data directly.""" - # Create a result set with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) - - # Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) - - # Fetch all rows as JSON - rows = result_set.fetchall_json() - - # Check that we got the raw data - assert rows == [[1, "value1"], [2, "value2"], [3, "value3"]] - - # Check that the row index was updated - assert result_set._next_row_index == 3 - - def test_iteration( - self, mock_connection, mock_sea_client, execute_response, mock_result_data - ): - """Test iterating over the result set.""" - # Create a result set with data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - result_data=mock_result_data, - ) - - # Replace the results queue with a JsonQueue containing test data - result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]]) - - # Iterate over the result set - rows = list(result_set) - - # Check that we got three Row objects with the correct values - assert len(rows) == 3 - assert isinstance(rows[0], Row) - assert rows[0].col1 == 1 - assert rows[0].col2 == "value1" - assert rows[1].col1 == 2 - assert rows[1].col2 == "value2" - assert rows[2].col1 == 3 - assert rows[2].col2 == "value3" - - # Check that the row index was updated - assert result_set._next_row_index == 3 + result_set._fill_results_buffer() From 563da71e389ed7ad68c57b64c7d1eb97c746f57c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 07:10:02 +0000 Subject: [PATCH 169/204] introduce more integration tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 5 +- tests/e2e/test_driver.py | 76 +++++++-- tests/unit/test_sea_conversion.py | 214 +++++++++++++++++++++++++ tests/unit/test_sea_queue.py | 172 ++++++++++++++++++++ tests/unit/test_sea_result_set.py | 256 ++++++++++++++++++++++++------ 5 files changed, 656 insertions(+), 67 deletions(-) create mode 100644 tests/unit/test_sea_conversion.py create mode 100644 tests/unit/test_sea_queue.py diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 95c9c4823..71c78ce59 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -471,9 +471,9 @@ def __init__( manifest: Manifest from SEA response (optional) """ - self.results = None + results_queue = None if result_data: - self.results = SeaResultSetQueueFactory.build_queue( + results_queue = SeaResultSetQueueFactory.build_queue( result_data, manifest, str(execute_response.command_id.to_sea_statement_id()), @@ -492,6 +492,7 @@ def __init__( command_id=execute_response.command_id, status=execute_response.status, 
has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 9d0d0141e..f4a992529 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -196,6 +196,17 @@ def test_execute_async__long_running(self): assert result[0].asDict() == {"count(1)": 0} + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) def test_execute_async__small_result(self, extra_params): small_result_query = "SELECT 1" @@ -352,8 +363,8 @@ def test_create_table_will_return_empty_result_set(self, extra_params): finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) - def test_get_tables(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_get_tables(self): + with self.cursor() as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -398,8 +409,8 @@ def test_get_tables(self, extra_params): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - def test_get_columns(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_get_columns(self): + with self.cursor() as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -521,8 +532,8 @@ def test_escape_single_quotes(self, extra_params): rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" - def test_get_schemas(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_get_schemas(self): + with self.cursor() as cursor: database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute("CREATE DATABASE IF NOT EXISTS {}".format(database_name)) @@ -539,8 +550,8 @@ def test_get_schemas(self, extra_params): finally: cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) - def test_get_catalogs(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_get_catalogs(self): + with self.cursor() as cursor: cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description @@ -813,8 +824,21 @@ def test_ssp_passthrough(self): assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_timestamps_arrow(self): - with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_timestamps_arrow(self, extra_params): + with self.cursor( + {"session_configuration": {"ansi_mode": False}, **extra_params} + ) as cursor: for timestamp, expected in self.timestamp_and_expected_results: cursor.execute( "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) @@ -837,8 +861,21 @@ def test_timestamps_arrow(self): ), "timestamp {} did not match {}".format(timestamp, expected) @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_multi_timestamps_arrow(self): - with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + 
{}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_multi_timestamps_arrow(self, extra_params): + with self.cursor( + {"session_configuration": {"ansi_mode": False}, **extra_params} + ) as cursor: query, expected = self.multi_query() expected = [ [self.maybe_add_timezone_to_timestamp(ts) for ts in row] @@ -855,9 +892,20 @@ def test_multi_timestamps_arrow(self): assert result == expected @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - def test_timezone_with_timestamp(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_timezone_with_timestamp(self, extra_params): if self.should_add_timezone(): - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SET TIME ZONE 'Europe/Amsterdam'") cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") amsterdam = pytz.timezone("Europe/Amsterdam") diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py new file mode 100644 index 000000000..99c178ab7 --- /dev/null +++ b/tests/unit/test_sea_conversion.py @@ -0,0 +1,214 @@ +""" +Tests for the conversion module in the SEA backend. + +This module contains tests for the SqlType and SqlTypeConverter classes. +""" + +import pytest +import datetime +import decimal +from unittest.mock import Mock, patch + +from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter + + +class TestSqlType: + """Test suite for the SqlType class.""" + + def test_is_numeric(self): + """Test the is_numeric method.""" + assert SqlType.is_numeric(SqlType.BYTE) + assert SqlType.is_numeric(SqlType.SHORT) + assert SqlType.is_numeric(SqlType.INT) + assert SqlType.is_numeric(SqlType.LONG) + assert SqlType.is_numeric(SqlType.FLOAT) + assert SqlType.is_numeric(SqlType.DOUBLE) + assert SqlType.is_numeric(SqlType.DECIMAL) + + # Test with uppercase + assert SqlType.is_numeric("INT") + assert SqlType.is_numeric("DECIMAL") + + # Test non-numeric types + assert not SqlType.is_numeric(SqlType.STRING) + assert not SqlType.is_numeric(SqlType.BOOLEAN) + assert not SqlType.is_numeric(SqlType.DATE) + + def test_is_boolean(self): + """Test the is_boolean method.""" + assert SqlType.is_boolean(SqlType.BOOLEAN) + assert SqlType.is_boolean("BOOLEAN") + + # Test non-boolean types + assert not SqlType.is_boolean(SqlType.STRING) + assert not SqlType.is_boolean(SqlType.INT) + + def test_is_datetime(self): + """Test the is_datetime method.""" + assert SqlType.is_datetime(SqlType.DATE) + assert SqlType.is_datetime(SqlType.TIMESTAMP) + assert SqlType.is_datetime(SqlType.INTERVAL) + assert SqlType.is_datetime("DATE") + assert SqlType.is_datetime("TIMESTAMP") + + # Test non-datetime types + assert not SqlType.is_datetime(SqlType.STRING) + assert not SqlType.is_datetime(SqlType.INT) + + def test_is_string(self): + """Test the is_string method.""" + assert SqlType.is_string(SqlType.STRING) + assert SqlType.is_string(SqlType.CHAR) + assert SqlType.is_string("STRING") + assert SqlType.is_string("CHAR") + + # Test non-string types + assert not SqlType.is_string(SqlType.INT) + assert not SqlType.is_string(SqlType.BOOLEAN) + + def test_is_binary(self): + """Test the is_binary method.""" + assert SqlType.is_binary(SqlType.BINARY) + assert SqlType.is_binary("BINARY") + + # Test non-binary types + assert not SqlType.is_binary(SqlType.STRING) + assert not 
SqlType.is_binary(SqlType.INT) + + def test_is_complex(self): + """Test the is_complex method.""" + assert SqlType.is_complex(SqlType.ARRAY) + assert SqlType.is_complex(SqlType.MAP) + assert SqlType.is_complex(SqlType.STRUCT) + assert SqlType.is_complex(SqlType.USER_DEFINED_TYPE) + assert SqlType.is_complex("ARRAY") + assert SqlType.is_complex("MAP") + assert SqlType.is_complex("STRUCT") + + # Test non-complex types + assert not SqlType.is_complex(SqlType.STRING) + assert not SqlType.is_complex(SqlType.INT) + + +class TestSqlTypeConverter: + """Test suite for the SqlTypeConverter class.""" + + def test_convert_value_null(self): + """Test converting null values.""" + assert SqlTypeConverter.convert_value(None, SqlType.INT) is None + assert SqlTypeConverter.convert_value(None, SqlType.STRING) is None + assert SqlTypeConverter.convert_value(None, SqlType.BOOLEAN) is None + + def test_convert_numeric_types(self): + """Test converting numeric types.""" + # Test integer types + assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123 + assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456 + assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789 + assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890 + + # Test floating point types + assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45 + assert SqlTypeConverter.convert_value("678.90", SqlType.DOUBLE) == 678.90 + + # Test decimal type + decimal_value = SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test decimal with precision and scale + decimal_value = SqlTypeConverter.convert_value( + "123.45", SqlType.DECIMAL, precision=5, scale=2 + ) + assert isinstance(decimal_value, decimal.Decimal) + assert decimal_value == decimal.Decimal("123.45") + + # Test invalid numeric input + result = SqlTypeConverter.convert_value("not_a_number", SqlType.INT) + assert result == "not_a_number" # Returns original value on error + + def test_convert_boolean_type(self): + """Test converting boolean types.""" + # True values + assert SqlTypeConverter.convert_value("true", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("True", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("t", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("1", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN) is True + assert SqlTypeConverter.convert_value("y", SqlType.BOOLEAN) is True + + # False values + assert SqlTypeConverter.convert_value("false", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("False", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("f", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("0", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("no", SqlType.BOOLEAN) is False + assert SqlTypeConverter.convert_value("n", SqlType.BOOLEAN) is False + + def test_convert_datetime_types(self): + """Test converting datetime types.""" + # Test date type + date_value = SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE) + assert isinstance(date_value, datetime.date) + assert date_value == datetime.date(2023, 1, 15) + + # Test timestamp type + timestamp_value = SqlTypeConverter.convert_value( + "2023-01-15T12:30:45", SqlType.TIMESTAMP + ) + assert isinstance(timestamp_value, datetime.datetime) + assert timestamp_value.year == 2023 + 
assert timestamp_value.month == 1 + assert timestamp_value.day == 15 + assert timestamp_value.hour == 12 + assert timestamp_value.minute == 30 + assert timestamp_value.second == 45 + + # Test interval type (currently returns as string) + interval_value = SqlTypeConverter.convert_value( + "1 day 2 hours", SqlType.INTERVAL + ) + assert interval_value == "1 day 2 hours" + + # Test invalid date input + result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE) + assert result == "not_a_date" # Returns original value on error + + def test_convert_string_types(self): + """Test converting string types.""" + # String types don't need conversion, they should be returned as-is + assert ( + SqlTypeConverter.convert_value("test string", SqlType.STRING) + == "test string" + ) + assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char" + + def test_convert_binary_type(self): + """Test converting binary type.""" + # Test valid hex string + binary_value = SqlTypeConverter.convert_value("48656C6C6F", SqlType.BINARY) + assert isinstance(binary_value, bytes) + assert binary_value == b"Hello" + + # Test invalid binary input + result = SqlTypeConverter.convert_value("not_hex", SqlType.BINARY) + assert result == "not_hex" # Returns original value on error + + def test_convert_unsupported_type(self): + """Test converting an unsupported type.""" + # Should return the original value + assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test" + + # Complex types should return as-is + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.MAP) + == "complex_value" + ) + assert ( + SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT) + == "complex_value" + ) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py new file mode 100644 index 000000000..92b94402c --- /dev/null +++ b/tests/unit/test_sea_queue.py @@ -0,0 +1,172 @@ +""" +Tests for SEA-related queue classes in utils.py. + +This module contains tests for the JsonQueue and SeaResultSetQueueFactory classes. 
+""" + +import pytest +from unittest.mock import Mock, MagicMock, patch + +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest + + +class TestJsonQueue: + """Test suite for the JsonQueue class.""" + + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", 1, True], + ["value2", 2, False], + ["value3", 3, True], + ["value4", 4, False], + ["value5", 5, True], + ] + + def test_init(self, sample_data): + """Test initialization of JsonQueue.""" + queue = JsonQueue(sample_data) + assert queue.data_array == sample_data + assert queue.cur_row_index == 0 + assert queue.n_valid_rows == len(sample_data) + + def test_next_n_rows_partial(self, sample_data): + """Test fetching a subset of rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(2) + assert result == sample_data[:2] + assert queue.cur_row_index == 2 + + def test_next_n_rows_all(self, sample_data): + """Test fetching all rows.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data)) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_more_than_available(self, sample_data): + """Test fetching more rows than available.""" + queue = JsonQueue(sample_data) + result = queue.next_n_rows(len(sample_data) + 10) + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_next_n_rows_after_partial(self, sample_data): + """Test fetching rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.next_n_rows(2) # Fetch next 2 rows + assert result == sample_data[2:4] + assert queue.cur_row_index == 4 + + def test_remaining_rows_all(self, sample_data): + """Test fetching all remaining rows at once.""" + queue = JsonQueue(sample_data) + result = queue.remaining_rows() + assert result == sample_data + assert queue.cur_row_index == len(sample_data) + + def test_remaining_rows_after_partial(self, sample_data): + """Test fetching remaining rows after a partial fetch.""" + queue = JsonQueue(sample_data) + queue.next_n_rows(2) # Fetch first 2 rows + result = queue.remaining_rows() # Fetch remaining rows + assert result == sample_data[2:] + assert queue.cur_row_index == len(sample_data) + + def test_empty_data(self): + """Test with empty data array.""" + queue = JsonQueue([]) + assert queue.next_n_rows(10) == [] + assert queue.remaining_rows() == [] + assert queue.cur_row_index == 0 + assert queue.n_valid_rows == 0 + + +class TestSeaResultSetQueueFactory: + """Test suite for the SeaResultSetQueueFactory class.""" + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + client = Mock() + client.max_download_threads = 10 + return client + + @pytest.fixture + def mock_description(self): + """Create a mock column description.""" + return [ + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), + ] + + def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): + """Test building a queue with inline JSON data.""" + # Create sample data for inline JSON result + data = [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ] + + # Create a ResultData object with inline data + result_data = ResultData(data=data, external_links=None, row_count=len(data)) + + # Create a manifest (not used 
for inline data) + manifest = None + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with the correct data + assert isinstance(queue, JsonQueue) + assert queue.data_array == data + assert queue.n_valid_rows == len(data) + + def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): + """Test building a queue with empty data.""" + # Create a ResultData object with no data + result_data = ResultData(data=None, external_links=None, row_count=0) + + # Build the queue + queue = SeaResultSetQueueFactory.build_queue( + result_data, + None, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) + + # Verify the queue is a JsonQueue with empty data + assert isinstance(queue, JsonQueue) + assert queue.data_array == [] + assert queue.n_valid_rows == 0 + + def test_build_queue_with_external_links(self, mock_sea_client, mock_description): + """Test building a queue with external links raises NotImplementedError.""" + # Create a ResultData object with external links + result_data = ResultData( + data=None, external_links=["link1", "link2"], row_count=10 + ) + + # Verify that NotImplementedError is raised + with pytest.raises( + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", + ): + SeaResultSetQueueFactory.build_queue( + result_data, + None, + "test-statement-123", + description=mock_description, + sea_client=mock_sea_client, + ) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index c596dbc14..f8a36657a 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -8,8 +8,10 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.result_set import SeaResultSet +from databricks.sql.result_set import SeaResultSet, Row +from databricks.sql.utils import JsonQueue from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest class TestSeaResultSet: @@ -37,11 +39,55 @@ def execute_response(self): mock_response.is_direct_results = False mock_response.results_queue = None mock_response.description = [ - ("test_value", "INT", None, None, None, None, None) + ("col1", "string", None, None, None, None, None), + ("col2", "int", None, None, None, None, None), + ("col3", "boolean", None, None, None, None, None), ] mock_response.is_staging_operation = False + mock_response.lz4_compressed = False + mock_response.arrow_schema_bytes = None return mock_response + @pytest.fixture + def sample_data(self): + """Create sample data for testing.""" + return [ + ["value1", "1", "true"], + ["value2", "2", "false"], + ["value3", "3", "true"], + ["value4", "4", "false"], + ["value5", "5", "true"], + ] + + @pytest.fixture + def result_set_with_data( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Create a SeaResultSet with sample data.""" + # Create ResultData with inline data + result_data = ResultData( + data=sample_data, external_links=None, row_count=len(sample_data) + ) + + # Initialize SeaResultSet with result data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + result_data=result_data, + manifest=None, + ) + 
result_set.results = JsonQueue(sample_data) + + return result_set + + @pytest.fixture + def json_queue(self, sample_data): + """Create a JsonQueue with sample data.""" + return JsonQueue(sample_data) + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -123,10 +169,139 @@ def test_close_when_connection_closed( assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_unimplemented_methods( - self, mock_connection, mock_sea_client, execute_response + def test_init_with_result_data(self, result_set_with_data, sample_data): + """Test initializing SeaResultSet with result data.""" + # Verify the results queue was created correctly + assert isinstance(result_set_with_data.results, JsonQueue) + assert result_set_with_data.results.data_array == sample_data + assert result_set_with_data.results.n_valid_rows == len(sample_data) + + def test_convert_json_types(self, result_set_with_data, sample_data): + """Test the _convert_json_types method.""" + # Call _convert_json_types + converted_rows = result_set_with_data._convert_json_types(sample_data) + + # Verify the conversion + assert len(converted_rows) == len(sample_data) + assert converted_rows[0][0] == "value1" # string stays as string + assert converted_rows[0][1] == 1 # "1" converted to int + assert converted_rows[0][2] is True # "true" converted to boolean + + def test_create_json_table(self, result_set_with_data, sample_data): + """Test the _create_json_table method.""" + # Call _create_json_table + result_rows = result_set_with_data._create_json_table(sample_data) + + # Verify the result + assert len(result_rows) == len(sample_data) + assert isinstance(result_rows[0], Row) + assert result_rows[0].col1 == "value1" + assert result_rows[0].col2 == 1 + assert result_rows[0].col3 is True + + def test_fetchmany_json(self, result_set_with_data): + """Test the fetchmany_json method.""" + # Test fetching a subset of rows + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 2 + + # Test fetching the next subset + result = result_set_with_data.fetchmany_json(2) + assert len(result) == 2 + assert result_set_with_data._next_row_index == 4 + + # Test fetching more than available + result = result_set_with_data.fetchmany_json(10) + assert len(result) == 1 # Only one row left + assert result_set_with_data._next_row_index == 5 + + def test_fetchall_json(self, result_set_with_data, sample_data): + """Test the fetchall_json method.""" + # Test fetching all rows + result = result_set_with_data.fetchall_json() + assert result == sample_data + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + result = result_set_with_data.fetchall_json() + assert result == [] + assert result_set_with_data._next_row_index == len(sample_data) + + def test_fetchone(self, result_set_with_data): + """Test the fetchone method.""" + # Test fetching one row at a time + row1 = result_set_with_data.fetchone() + assert isinstance(row1, Row) + assert row1.col1 == "value1" + assert row1.col2 == 1 + assert row1.col3 is True + assert result_set_with_data._next_row_index == 1 + + row2 = result_set_with_data.fetchone() + assert isinstance(row2, Row) + assert row2.col1 == "value2" + assert row2.col2 == 2 + assert row2.col3 is False + assert result_set_with_data._next_row_index == 2 + + # Fetch the rest + result_set_with_data.fetchall() + + # Test fetching when no 
more rows + row_none = result_set_with_data.fetchone() + assert row_none is None + + def test_fetchmany(self, result_set_with_data): + """Test the fetchmany method.""" + # Test fetching multiple rows + rows = result_set_with_data.fetchmany(2) + assert len(rows) == 2 + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert rows[1].col1 == "value2" + assert rows[1].col2 == 2 + assert rows[1].col3 is False + assert result_set_with_data._next_row_index == 2 + + # Test with invalid size + with pytest.raises( + ValueError, match="size argument for fetchmany is -1 but must be >= 0" + ): + result_set_with_data.fetchmany(-1) + + def test_fetchall(self, result_set_with_data, sample_data): + """Test the fetchall method.""" + # Test fetching all rows + rows = result_set_with_data.fetchall() + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + assert result_set_with_data._next_row_index == len(sample_data) + + # Test fetching again (should return empty) + rows = result_set_with_data.fetchall() + assert len(rows) == 0 + + def test_iteration(self, result_set_with_data, sample_data): + """Test iterating over the result set.""" + # Test iteration + rows = list(result_set_with_data) + assert len(rows) == len(sample_data) + assert isinstance(rows[0], Row) + assert rows[0].col1 == "value1" + assert rows[0].col2 == 1 + assert rows[0].col3 is True + + def test_fetchmany_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data ): - """Test that unimplemented methods raise NotImplementedError.""" + """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" + # Create a result set without JSON data result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -135,57 +310,39 @@ def test_unimplemented_methods( arraysize=100, ) - # Test each unimplemented method individually with specific error messages - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - result_set.fetchone() - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - result_set.fetchmany(10) - - with pytest.raises( - NotImplementedError, match="fetchmany is not implemented for SEA backend" - ): - # Test with default parameter value - result_set.fetchmany() - + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchall is not implemented for SEA backend" - ): - result_set.fetchall() - - with pytest.raises( - NotImplementedError, - match="fetchmany_arrow is not implemented for SEA backend", + NotImplementedError, match="fetchmany_arrow only supported for JSON data" ): result_set.fetchmany_arrow(10) - with pytest.raises( - NotImplementedError, - match="fetchall_arrow is not implemented for SEA backend", - ): - result_set.fetchall_arrow() - - with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" - ): - # Test iteration protocol (calls fetchone internally) - next(iter(result_set)) + def test_fetchall_arrow_not_implemented( + self, mock_connection, mock_sea_client, execute_response, sample_data + ): + """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + 
execute_response=execute_response, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchone is not implemented for SEA backend" + NotImplementedError, match="fetchall_arrow only supported for JSON data" ): - # Test using the result set in a for loop - for row in result_set: - pass + result_set.fetchall_arrow() - def test_fill_results_buffer_not_implemented( + def test_is_staging_operation( self, mock_connection, mock_sea_client, execute_response ): - """Test that _fill_results_buffer raises NotImplementedError.""" + """Test the is_staging_operation property.""" + # Set is_staging_operation to True + execute_response.is_staging_operation = True + + # Create a result set result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, @@ -194,8 +351,5 @@ def test_fill_results_buffer_not_implemented( arraysize=100, ) - with pytest.raises( - NotImplementedError, - match="_fill_results_buffer is not implemented for SEA backend", - ): - result_set._fill_results_buffer() + # Test the property + assert result_set.is_staging_operation is True From a01827347403db3164277b675d921b02213bfffe Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 07:13:45 +0000 Subject: [PATCH 170/204] remove SEA tests in parameterized queries Signed-off-by: varun-edachali-dbx --- tests/e2e/test_parameterized_queries.py | 126 +++--------------------- 1 file changed, 16 insertions(+), 110 deletions(-) diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index e696c667b..686178ffa 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -405,20 +405,9 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) "Consider using native parameters." not in caplog.text ), "Log message should not be supressed" - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_positional_native_params_with_defaults(self, extra_params): + def test_positional_native_params_with_defaults(self): query = "SELECT ? col" - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: result = cursor.execute(query, parameters=[1]).fetchone() assert result.col == 1 @@ -434,22 +423,10 @@ def test_positional_native_params_with_defaults(self, extra_params): ["foo", "bar", "baz"], ), ) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_positional_native_multiple(self, params, extra_params): + def test_positional_native_multiple(self, params): query = "SELECT ? `foo`, ? `bar`, ? 
`baz`" - combined_params = {"use_inline_params": False, **extra_params} - with self.cursor(extra_params=combined_params) as cursor: + with self.cursor(extra_params={"use_inline_params": False}) as cursor: result = cursor.execute(query, params).fetchone() expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params] @@ -457,19 +434,8 @@ def test_positional_native_multiple(self, params, extra_params): assert set(outcome) == set(expected) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_readme_example(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_readme_example(self): + with self.cursor() as cursor: result = cursor.execute( "SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"} ).fetchall() @@ -533,23 +499,11 @@ def test_native_recursive_complex_type( class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_params_as_dict(self, extra_params): + def test_params_as_dict(self): query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz" params = {"foo": 1, "bar": 2, "baz": 3} - combined_params = {"use_inline_params": True, **extra_params} - with self.connection(extra_params=combined_params) as conn: + with self.connection(extra_params={"use_inline_params": True}) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() @@ -557,18 +511,7 @@ def test_params_as_dict(self, extra_params): assert result.bar == 2 assert result.baz == 3 - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_params_as_sequence(self, extra_params): + def test_params_as_sequence(self): """One side-effect of ParamEscaper using Python string interpolation to inline the values is that it can work with "ordinal" parameters, but only if a user writes parameter markers that are not defined with PEP-249. This test exists to prove that it works in the ideal case. @@ -578,8 +521,7 @@ def test_params_as_sequence(self, extra_params): query = "SELECT %s foo, %s bar, %s baz" params = (1, 2, 3) - combined_params = {"use_inline_params": True, **extra_params} - with self.connection(extra_params=combined_params) as conn: + with self.connection(extra_params={"use_inline_params": True}) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.foo == 1 @@ -599,18 +541,7 @@ def test_inline_ordinals_can_break_sql(self): ): cursor.execute(query, parameters=params) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_inline_named_dont_break_sql(self, extra_params): + def test_inline_named_dont_break_sql(self): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. 
@@ -620,30 +551,17 @@ def test_inline_named_dont_break_sql(self, extra_params): SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite') """ params = {"one": "%(one)s"} - combined_params = {"use_inline_params": True, **extra_params} - with self.cursor(extra_params=combined_params) as cursor: + with self.cursor(extra_params={"use_inline_params": True}) as cursor: result = cursor.execute(query, parameters=params).fetchone() print("hello") - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_native_ordinals_dont_break_sql(self, extra_params): + def test_native_ordinals_dont_break_sql(self): """This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal parameters work in native mode for the exact same query, if we use the right marker `?` """ query = "SELECT 'samsonite', ? WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - combined_params = {"use_inline_params": False, **extra_params} - with self.cursor(extra_params=combined_params) as cursor: + with self.cursor(extra_params={"use_inline_params": False}) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.samsonite == "samsonite" @@ -659,25 +577,13 @@ def test_inline_like_wildcard_breaks(self): with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_native_like_wildcard_works(self, extra_params): + def test_native_like_wildcard_works(self): """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. 
""" query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - combined_params = {"use_inline_params": False, **extra_params} - with self.cursor(extra_params=combined_params) as cursor: + with self.cursor(extra_params={"use_inline_params": False}) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.col == 1 From c0e98f4c3098e154b6e5cee63e8fcab169a8b776 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 07:15:50 +0000 Subject: [PATCH 171/204] remove partial parameter fix changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 9 ++++----- tests/unit/test_sea_backend.py | 6 +----- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index d3a90ed10..0c0400ae2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -14,7 +14,6 @@ WaitTimeout, MetadataCommands, ) -from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -406,7 +405,7 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List[Dict[str, Any]], async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union[SeaResultSet, None]: @@ -440,9 +439,9 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param.name, - value=param.value.stringValue, - type=param.type, + name=param["name"], + value=param["value"], + type=param["type"] if "type" in param else None, ) ) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index b79aaa093..bc6768d2b 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -355,11 +355,7 @@ def test_command_execution_advanced( "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - param = Mock() - param.name = "param1" - param.value = Mock() - param.value.stringValue = "value1" - param.type = "STRING" + param = {"name": "param1", "value": "value1", "type": "STRING"} with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( From 7343035945561a4785bf9bdd73b2c13ddc33a5cf Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:34:22 +0000 Subject: [PATCH 172/204] remove un-necessary timestamp tests (pass with minor disparity) Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 43 ++++------------------------------------ 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index f4a992529..49ac1503c 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -496,17 +496,6 @@ def test_get_columns(self): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) def test_escape_single_quotes(self, extra_params): with self.cursor(extra_params) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) @@ -824,21 +813,8 @@ def test_ssp_passthrough(self): assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - @pytest.mark.parametrize( - "extra_params", - [ 
- {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_timestamps_arrow(self, extra_params): - with self.cursor( - {"session_configuration": {"ansi_mode": False}, **extra_params} - ) as cursor: + def test_timestamps_arrow(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: cursor.execute( "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) @@ -892,20 +868,9 @@ def test_multi_timestamps_arrow(self, extra_params): assert result == expected @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") - @pytest.mark.parametrize( - "extra_params", - [ - {}, - { - "use_sea": True, - "use_cloud_fetch": False, - "enable_query_result_lz4_compression": False, - }, - ], - ) - def test_timezone_with_timestamp(self, extra_params): + def test_timezone_with_timestamp(self): if self.should_add_timezone(): - with self.cursor(extra_params) as cursor: + with self.cursor() as cursor: cursor.execute("SET TIME ZONE 'Europe/Amsterdam'") cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") amsterdam = pytz.timezone("Europe/Amsterdam") From ec500b620c9bbf84fa381a779c77dae685e2c208 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:37:22 +0000 Subject: [PATCH 173/204] slightly stronger typing of _convert_json_types Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 71c78ce59..06f98c88c 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -510,7 +510,7 @@ def _convert_json_to_arrow(self, rows): names = [col[0] for col in self.description] return pyarrow.Table.from_arrays(columns, names=names) - def _convert_json_types(self, rows): + def _convert_json_types(self, rows: List) -> List: """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. From 0b3e91d612fa7528eb0d5498e5b81998b8425494 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:38:09 +0000 Subject: [PATCH 174/204] stronger typing of json utility func s Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 06f98c88c..64fd9cbed 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -499,7 +499,7 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", ) - def _convert_json_to_arrow(self, rows): + def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. """ @@ -544,7 +544,7 @@ def _convert_json_types(self, rows: List) -> List: return converted_rows - def _create_json_table(self, rows): + def _create_json_table(self, rows: List) -> List[Row]: """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. 
From 7664e44f2de52a2d50901b8f80af45732d4dc04c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:39:13 +0000 Subject: [PATCH 175/204] stronger typing of fetch*_json Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 64fd9cbed..ec4c0aadb 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -562,7 +562,7 @@ def _create_json_table(self, rows: List) -> List[Row]: return [ResultRow(*row) for row in rows] - def fetchmany_json(self, size: int): + def fetchmany_json(self, size: int) -> List: """ Fetch the next set of rows as a columnar table. @@ -583,7 +583,7 @@ def fetchmany_json(self, size: int): return results - def fetchall_json(self): + def fetchall_json(self) -> List: """ Fetch all remaining rows as a columnar table. From db7b8e57ec07e079b6d5897840a653996e0f464c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:41:34 +0000 Subject: [PATCH 176/204] remove unused helper methods in SqlType Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/conversion.py | 44 ------ tests/unit/test_conversion.py | 158 ------------------- tests/unit/test_sea_conversion.py | 78 --------- 3 files changed, 280 deletions(-) delete mode 100644 tests/unit/test_conversion.py diff --git a/src/databricks/sql/backend/sea/conversion.py b/src/databricks/sql/backend/sea/conversion.py index 019305bf0..a3edd6dcc 100644 --- a/src/databricks/sql/backend/sea/conversion.py +++ b/src/databricks/sql/backend/sea/conversion.py @@ -55,50 +55,6 @@ class SqlType: NULL = "null" USER_DEFINED_TYPE = "user_defined_type" - @classmethod - def is_numeric(cls, sql_type: str) -> bool: - """Check if the SQL type is a numeric type.""" - return sql_type.lower() in ( - cls.BYTE, - cls.SHORT, - cls.INT, - cls.LONG, - cls.FLOAT, - cls.DOUBLE, - cls.DECIMAL, - ) - - @classmethod - def is_boolean(cls, sql_type: str) -> bool: - """Check if the SQL type is a boolean type.""" - return sql_type.lower() == cls.BOOLEAN - - @classmethod - def is_datetime(cls, sql_type: str) -> bool: - """Check if the SQL type is a date/time type.""" - return sql_type.lower() in (cls.DATE, cls.TIMESTAMP, cls.INTERVAL) - - @classmethod - def is_string(cls, sql_type: str) -> bool: - """Check if the SQL type is a string type.""" - return sql_type.lower() in (cls.CHAR, cls.STRING) - - @classmethod - def is_binary(cls, sql_type: str) -> bool: - """Check if the SQL type is a binary type.""" - return sql_type.lower() == cls.BINARY - - @classmethod - def is_complex(cls, sql_type: str) -> bool: - """Check if the SQL type is a complex type.""" - sql_type = sql_type.lower() - return ( - sql_type.startswith(cls.ARRAY) - or sql_type.startswith(cls.MAP) - or sql_type.startswith(cls.STRUCT) - or sql_type == cls.USER_DEFINED_TYPE - ) - class SqlTypeConverter: """ diff --git a/tests/unit/test_conversion.py b/tests/unit/test_conversion.py deleted file mode 100644 index 656e6730a..000000000 --- a/tests/unit/test_conversion.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -Unit tests for the type conversion utilities. 
-""" - -import unittest -from datetime import date, datetime, time -from decimal import Decimal - -from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter - - -class TestSqlType(unittest.TestCase): - """Tests for the SqlType class.""" - - def test_is_numeric(self): - """Test the is_numeric method.""" - self.assertTrue(SqlType.is_numeric(SqlType.INT)) - self.assertTrue(SqlType.is_numeric(SqlType.BYTE)) - self.assertTrue(SqlType.is_numeric(SqlType.SHORT)) - self.assertTrue(SqlType.is_numeric(SqlType.LONG)) - self.assertTrue(SqlType.is_numeric(SqlType.FLOAT)) - self.assertTrue(SqlType.is_numeric(SqlType.DOUBLE)) - self.assertTrue(SqlType.is_numeric(SqlType.DECIMAL)) - self.assertFalse(SqlType.is_numeric(SqlType.BOOLEAN)) - self.assertFalse(SqlType.is_numeric(SqlType.STRING)) - self.assertFalse(SqlType.is_numeric(SqlType.DATE)) - - def test_is_boolean(self): - """Test the is_boolean method.""" - self.assertTrue(SqlType.is_boolean(SqlType.BOOLEAN)) - self.assertFalse(SqlType.is_boolean(SqlType.INT)) - self.assertFalse(SqlType.is_boolean(SqlType.STRING)) - - def test_is_datetime(self): - """Test the is_datetime method.""" - self.assertTrue(SqlType.is_datetime(SqlType.DATE)) - self.assertTrue(SqlType.is_datetime(SqlType.TIMESTAMP)) - self.assertTrue(SqlType.is_datetime(SqlType.INTERVAL)) - self.assertFalse(SqlType.is_datetime(SqlType.INT)) - self.assertFalse(SqlType.is_datetime(SqlType.STRING)) - - def test_is_string(self): - """Test the is_string method.""" - self.assertTrue(SqlType.is_string(SqlType.CHAR)) - self.assertTrue(SqlType.is_string(SqlType.STRING)) - self.assertFalse(SqlType.is_string(SqlType.INT)) - self.assertFalse(SqlType.is_string(SqlType.DATE)) - - def test_is_binary(self): - """Test the is_binary method.""" - self.assertTrue(SqlType.is_binary(SqlType.BINARY)) - self.assertFalse(SqlType.is_binary(SqlType.INT)) - self.assertFalse(SqlType.is_binary(SqlType.STRING)) - - def test_is_complex(self): - """Test the is_complex method.""" - self.assertTrue(SqlType.is_complex("array")) - self.assertTrue(SqlType.is_complex("map")) - self.assertTrue(SqlType.is_complex("struct")) - self.assertFalse(SqlType.is_complex(SqlType.INT)) - self.assertFalse(SqlType.is_complex(SqlType.STRING)) - - -class TestSqlTypeConverter(unittest.TestCase): - """Tests for the SqlTypeConverter class.""" - - def test_numeric_conversions(self): - """Test numeric type conversions.""" - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.INT), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.BYTE), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.SHORT), 123) - self.assertEqual(SqlTypeConverter.convert_value("123", SqlType.LONG), 123) - self.assertEqual( - SqlTypeConverter.convert_value("123.45", SqlType.FLOAT), 123.45 - ) - self.assertEqual( - SqlTypeConverter.convert_value("123.45", SqlType.DOUBLE), 123.45 - ) - self.assertEqual( - SqlTypeConverter.convert_value("123.45", SqlType.DECIMAL), Decimal("123.45") - ) - - # Test decimal with precision and scale - self.assertEqual( - SqlTypeConverter.convert_value( - "123.456", SqlType.DECIMAL, precision=5, scale=2 - ), - Decimal("123.46"), # Rounded to scale 2 - ) - - def test_boolean_conversions(self): - """Test boolean type conversions.""" - self.assertTrue(SqlTypeConverter.convert_value("true", SqlType.BOOLEAN)) - self.assertTrue(SqlTypeConverter.convert_value("TRUE", SqlType.BOOLEAN)) - self.assertTrue(SqlTypeConverter.convert_value("1", SqlType.BOOLEAN)) - 
self.assertTrue(SqlTypeConverter.convert_value("yes", SqlType.BOOLEAN)) - self.assertFalse(SqlTypeConverter.convert_value("false", SqlType.BOOLEAN)) - self.assertFalse(SqlTypeConverter.convert_value("FALSE", SqlType.BOOLEAN)) - self.assertFalse(SqlTypeConverter.convert_value("0", SqlType.BOOLEAN)) - self.assertFalse(SqlTypeConverter.convert_value("no", SqlType.BOOLEAN)) - - def test_datetime_conversions(self): - """Test date/time type conversions.""" - self.assertEqual( - SqlTypeConverter.convert_value("2023-01-15", SqlType.DATE), - date(2023, 1, 15), - ) - self.assertEqual( - SqlTypeConverter.convert_value("2023-01-15 14:30:45", SqlType.TIMESTAMP), - datetime(2023, 1, 15, 14, 30, 45), - ) - - def test_string_conversions(self): - """Test string type conversions.""" - self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.STRING), "test") - self.assertEqual(SqlTypeConverter.convert_value("test", SqlType.CHAR), "test") - - def test_binary_conversions(self): - """Test binary type conversions.""" - hex_str = "68656c6c6f" # "hello" in hex - expected_bytes = b"hello" - - self.assertEqual( - SqlTypeConverter.convert_value(hex_str, SqlType.BINARY), expected_bytes - ) - - def test_error_handling(self): - """Test error handling in conversions.""" - self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.INT), "abc") - self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.FLOAT), "abc") - self.assertEqual(SqlTypeConverter.convert_value("abc", SqlType.DECIMAL), "abc") - - def test_null_handling(self): - """Test handling of NULL values.""" - self.assertIsNone(SqlTypeConverter.convert_value(None, SqlType.INT)) - self.assertIsNone(SqlTypeConverter.convert_value(None, SqlType.STRING)) - self.assertIsNone(SqlTypeConverter.convert_value(None, SqlType.DATE)) - - def test_complex_type_handling(self): - """Test handling of complex types.""" - # Complex types should be returned as-is for now - self.assertEqual( - SqlTypeConverter.convert_value('{"a": 1}', "array"), '{"a": 1}' - ) - self.assertEqual( - SqlTypeConverter.convert_value('{"a": 1}', "map"), '{"a": 1}' - ) - self.assertEqual( - SqlTypeConverter.convert_value('{"a": 1}', "struct"), '{"a": 1}' - ) - self.assertEqual( - SqlTypeConverter.convert_value('{"a": 1}', SqlType.USER_DEFINED_TYPE), - '{"a": 1}', - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py index 99c178ab7..738889975 100644 --- a/tests/unit/test_sea_conversion.py +++ b/tests/unit/test_sea_conversion.py @@ -12,84 +12,6 @@ from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter -class TestSqlType: - """Test suite for the SqlType class.""" - - def test_is_numeric(self): - """Test the is_numeric method.""" - assert SqlType.is_numeric(SqlType.BYTE) - assert SqlType.is_numeric(SqlType.SHORT) - assert SqlType.is_numeric(SqlType.INT) - assert SqlType.is_numeric(SqlType.LONG) - assert SqlType.is_numeric(SqlType.FLOAT) - assert SqlType.is_numeric(SqlType.DOUBLE) - assert SqlType.is_numeric(SqlType.DECIMAL) - - # Test with uppercase - assert SqlType.is_numeric("INT") - assert SqlType.is_numeric("DECIMAL") - - # Test non-numeric types - assert not SqlType.is_numeric(SqlType.STRING) - assert not SqlType.is_numeric(SqlType.BOOLEAN) - assert not SqlType.is_numeric(SqlType.DATE) - - def test_is_boolean(self): - """Test the is_boolean method.""" - assert SqlType.is_boolean(SqlType.BOOLEAN) - assert SqlType.is_boolean("BOOLEAN") - - # Test non-boolean types - assert not 
SqlType.is_boolean(SqlType.STRING) - assert not SqlType.is_boolean(SqlType.INT) - - def test_is_datetime(self): - """Test the is_datetime method.""" - assert SqlType.is_datetime(SqlType.DATE) - assert SqlType.is_datetime(SqlType.TIMESTAMP) - assert SqlType.is_datetime(SqlType.INTERVAL) - assert SqlType.is_datetime("DATE") - assert SqlType.is_datetime("TIMESTAMP") - - # Test non-datetime types - assert not SqlType.is_datetime(SqlType.STRING) - assert not SqlType.is_datetime(SqlType.INT) - - def test_is_string(self): - """Test the is_string method.""" - assert SqlType.is_string(SqlType.STRING) - assert SqlType.is_string(SqlType.CHAR) - assert SqlType.is_string("STRING") - assert SqlType.is_string("CHAR") - - # Test non-string types - assert not SqlType.is_string(SqlType.INT) - assert not SqlType.is_string(SqlType.BOOLEAN) - - def test_is_binary(self): - """Test the is_binary method.""" - assert SqlType.is_binary(SqlType.BINARY) - assert SqlType.is_binary("BINARY") - - # Test non-binary types - assert not SqlType.is_binary(SqlType.STRING) - assert not SqlType.is_binary(SqlType.INT) - - def test_is_complex(self): - """Test the is_complex method.""" - assert SqlType.is_complex(SqlType.ARRAY) - assert SqlType.is_complex(SqlType.MAP) - assert SqlType.is_complex(SqlType.STRUCT) - assert SqlType.is_complex(SqlType.USER_DEFINED_TYPE) - assert SqlType.is_complex("ARRAY") - assert SqlType.is_complex("MAP") - assert SqlType.is_complex("STRUCT") - - # Test non-complex types - assert not SqlType.is_complex(SqlType.STRING) - assert not SqlType.is_complex(SqlType.INT) - - class TestSqlTypeConverter: """Test suite for the SqlTypeConverter class.""" From f75f2b53c735b23a5a471ef4e4374b4f7330b053 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 08:57:12 +0000 Subject: [PATCH 177/204] line breaks after multi line pydocs, remove excess logs Signed-off-by: varun-edachali-dbx --- examples/experimental/tests/test_sea_async_query.py | 6 ------ src/databricks/sql/result_set.py | 9 +++++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 1685ac4ca..3c0e325fe 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -82,9 +82,6 @@ def test_sea_async_query_with_cloud_fetch(): results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) actual_row_count = len(results) - logger.info( - f"{actual_row_count} rows retrieved against {requested_row_count} requested" - ) logger.info( f"Requested {requested_row_count} rows, received {actual_row_count} rows" @@ -188,9 +185,6 @@ def test_sea_async_query_without_cloud_fetch(): results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) actual_row_count = len(results) - logger.info( - f"{actual_row_count} rows retrieved against {requested_row_count} requested" - ) logger.info( f"Requested {requested_row_count} rows, received {actual_row_count} rows" diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ec4c0aadb..b1e067ad1 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -515,6 +515,7 @@ def _convert_json_types(self, rows: List) -> List: Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. 
""" + if not self.description or not rows: return rows @@ -554,6 +555,7 @@ def _create_json_table(self, rows: List) -> List[Row]: Returns: List of Row objects with named columns and converted values """ + if not self.description or not rows: return rows @@ -575,6 +577,7 @@ def fetchmany_json(self, size: int) -> List: Raises: ValueError: If size is negative """ + if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") @@ -590,6 +593,7 @@ def fetchall_json(self) -> List: Returns: Columnar table containing all remaining rows """ + results = self.results.remaining_rows() self._next_row_index += len(results) @@ -609,6 +613,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ImportError: If PyArrow is not installed ValueError: If size is negative """ + if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") @@ -625,6 +630,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": """ Fetch all remaining rows as an Arrow table. """ + if not isinstance(self.results, JsonQueue): raise NotImplementedError("fetchall_arrow only supported for JSON data") @@ -642,6 +648,7 @@ def fetchone(self) -> Optional[Row]: Returns: A single Row object or None if no more rows are available """ + if isinstance(self.results, JsonQueue): res = self._create_json_table(self.fetchmany_json(1)) else: @@ -662,6 +669,7 @@ def fetchmany(self, size: int) -> List[Row]: Raises: ValueError: If size is negative """ + if isinstance(self.results, JsonQueue): return self._create_json_table(self.fetchmany_json(size)) else: @@ -674,6 +682,7 @@ def fetchall(self) -> List[Row]: Returns: List of Row objects containing all remaining rows """ + if isinstance(self.results, JsonQueue): return self._create_json_table(self.fetchall_json()) else: From e2d4ef5767c3255ba1075a4ec9155c6dd4d2b5cd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 09:03:56 +0000 Subject: [PATCH 178/204] line breaks after multi line pydocs, reduce diff of redundant changes Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 1 + tests/e2e/test_driver.py | 10 +++++----- tests/e2e/test_parameterized_queries.py | 1 - 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index b1e067ad1..5eb529a83 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -503,6 +503,7 @@ def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. 
""" + columns = [] num_cols = len(rows[0]) for i in range(num_cols): diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 49ac1503c..476066e2c 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -496,8 +496,8 @@ def test_get_columns(self): for table in table_names: cursor.execute("DROP TABLE IF EXISTS {}".format(table)) - def test_escape_single_quotes(self, extra_params): - with self.cursor(extra_params) as cursor: + def test_escape_single_quotes(self): + with self.cursor({}) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly cursor.execute( @@ -522,7 +522,7 @@ def test_escape_single_quotes(self, extra_params): assert rows[0]["col_1"] == "you're" def test_get_schemas(self): - with self.cursor() as cursor: + with self.cursor({}) as cursor: database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute("CREATE DATABASE IF NOT EXISTS {}".format(database_name)) @@ -540,7 +540,7 @@ def test_get_schemas(self): cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) def test_get_catalogs(self): - with self.cursor() as cursor: + with self.cursor({}) as cursor: cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description @@ -591,7 +591,7 @@ def test_unicode(self, extra_params): assert results[0][0] == unicode_str def test_cancel_during_execute(self): - with self.cursor() as cursor: + with self.cursor({}) as cursor: def execute_really_long_query(): cursor.execute( diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 686178ffa..79def9b72 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -2,7 +2,6 @@ from contextlib import contextmanager from decimal import Decimal from enum import Enum -import json from typing import Dict, List, Type, Union from unittest.mock import patch From 21e30783c16085ac76aaf5ab9508105b06df64c6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 09:05:17 +0000 Subject: [PATCH 179/204] reduce diff of redundant changes Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 476066e2c..5848d780b 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -364,7 +364,7 @@ def test_create_table_will_return_empty_result_set(self, extra_params): cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) def test_get_tables(self): - with self.cursor() as cursor: + with self.cursor({}) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] @@ -410,7 +410,7 @@ def test_get_tables(self): cursor.execute("DROP TABLE IF EXISTS {}".format(table)) def test_get_columns(self): - with self.cursor() as cursor: + with self.cursor({}) as cursor: table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) table_names = [table_name + "_1", table_name + "_2"] From bb015e6f2ae901c2ec1c1070bb61459e3101e33c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 09:23:08 +0000 Subject: [PATCH 180/204] mandate ResultData in SeaResultSet constructor Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 +- .../sql/backend/sea/utils/filters.py | 2 +- src/databricks/sql/result_set.py | 38 +++++++------- src/databricks/sql/utils.py | 7 ++- 
tests/unit/test_sea_queue.py | 2 +- tests/unit/test_sea_result_set.py | 52 +++++++++++-------- 6 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 0c0400ae2..2ed248c3d 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -616,10 +616,10 @@ def get_execution_result( connection=cursor.connection, execute_response=execute_response, sea_client=self, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, result_data=response.result, manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 1b7660829..f3bf4669a 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -77,9 +77,9 @@ def _filter_sea_result_set( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), + result_data=result_data, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, - result_data=result_data, ) return filtered_result_set diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 5eb529a83..a4814db57 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import List, Optional, TYPE_CHECKING @@ -450,13 +452,13 @@ class SeaResultSet(ResultSet): def __init__( self, - connection: "Connection", - execute_response: "ExecuteResponse", - sea_client: "SeaDatabricksClient", + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + result_data: ResultData, + manifest: Optional[ResultManifest] = None, buffer_size_bytes: int = 104857600, arraysize: int = 10000, - result_data: Optional["ResultData"] = None, - manifest: Optional["ResultManifest"] = None, ): """ Initialize a SeaResultSet with the response from a SEA query execution. @@ -467,21 +469,19 @@ def __init__( sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch - result_data: Result data from SEA response (optional) - manifest: Manifest from SEA response (optional) + result_data: Result data from SEA response + manifest: Manifest from SEA response """ - results_queue = None - if result_data: - results_queue = SeaResultSetQueueFactory.build_queue( - result_data, - manifest, - str(execute_response.command_id.to_sea_statement_id()), - description=execute_response.description, - max_download_threads=sea_client.max_download_threads, - sea_client=sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + manifest, + str(execute_response.command_id.to_sea_statement_id()), + description=execute_response.description, + max_download_threads=sea_client.max_download_threads, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) # Call parent constructor with common attributes super().__init__( @@ -503,6 +503,8 @@ def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. 
""" + if not rows: + return pyarrow.Table.from_pydict({}) columns = [] num_cols = len(rows[0]) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 933032044..22a590fe6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -21,7 +21,8 @@ except ImportError: pyarrow = None -from databricks.sql import OperationalError, exc +from databricks.sql import OperationalError +from databricks.sql.exc import ProgrammingError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( TRowSet, @@ -148,9 +149,7 @@ def build_queue( raise NotImplementedError( "EXTERNAL_LINKS disposition is not implemented for SEA backend" ) - else: - # Empty result set - return JsonQueue([]) + raise ProgrammingError("No result data or external links found") class JsonQueue(ResultSetQueue): diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 92b94402c..4a4dee8f5 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -135,7 +135,7 @@ def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): """Test building a queue with empty data.""" # Create a ResultData object with no data - result_data = ResultData(data=None, external_links=None, row_count=0) + result_data = ResultData(data=[], external_links=None, row_count=0) # Build the queue queue = SeaResultSetQueueFactory.build_queue( diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index f8a36657a..775b42d13 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -74,10 +74,10 @@ def result_set_with_data( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, result_data=result_data, manifest=None, + buffer_size_bytes=1000, + arraysize=100, ) result_set.results = JsonQueue(sample_data) @@ -96,6 +96,7 @@ def test_init_with_execute_response( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -115,6 +116,7 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -135,6 +137,7 @@ def test_close_when_already_closed_server_side( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -157,6 +160,7 @@ def test_close_when_connection_closed( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) @@ -301,39 +305,40 @@ def test_fetchmany_arrow_not_implemented( self, mock_connection, mock_sea_client, execute_response, sample_data ): """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" - # Create a result set without JSON data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchmany_arrow only supported for 
JSON data" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - result_set.fetchmany_arrow(10) + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + buffer_size_bytes=1000, + arraysize=100, + ) def test_fetchall_arrow_not_implemented( self, mock_connection, mock_sea_client, execute_response, sample_data ): """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" - # Create a result set without JSON data - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - buffer_size_bytes=1000, - arraysize=100, - ) - # Test that NotImplementedError is raised with pytest.raises( - NotImplementedError, match="fetchall_arrow only supported for JSON data" + NotImplementedError, + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", ): - result_set.fetchall_arrow() + # Create a result set without JSON data + result_set = SeaResultSet( + connection=mock_connection, + execute_response=execute_response, + sea_client=mock_sea_client, + result_data=ResultData(data=None, external_links=[]), + buffer_size_bytes=1000, + arraysize=100, + ) def test_is_staging_operation( self, mock_connection, mock_sea_client, execute_response @@ -347,6 +352,7 @@ def test_is_staging_operation( connection=mock_connection, execute_response=execute_response, sea_client=mock_sea_client, + result_data=ResultData(data=[]), buffer_size_bytes=1000, arraysize=100, ) From bb948a0380a82ad9b09bf618138bf961f0be99c3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 09:58:29 +0000 Subject: [PATCH 181/204] return empty JsonQueue in case of empty response test ref: test_create_table_will_return_empty_result_set Signed-off-by: varun-edachali-dbx --- src/databricks/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 22a590fe6..051a6995d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -149,7 +149,7 @@ def build_queue( raise NotImplementedError( "EXTERNAL_LINKS disposition is not implemented for SEA backend" ) - raise ProgrammingError("No result data or external links found") + return JsonQueue([]) class JsonQueue(ResultSetQueue): From 921a8c17a6e70fa7dc1e63cc7e879252766bdce7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 10:00:59 +0000 Subject: [PATCH 182/204] remove string literals around SeaDatabricksClient declaration Signed-off-by: varun-edachali-dbx --- src/databricks/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 051a6995d..fcf39df33 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -120,7 +120,7 @@ def build_queue( description: Optional[List[Tuple[Any, ...]]] = None, schema_bytes: Optional[bytes] = None, max_download_threads: Optional[int] = None, - sea_client: Optional["SeaDatabricksClient"] = None, + sea_client: Optional[SeaDatabricksClient] = None, lz4_compressed: bool = False, ) -> ResultSetQueue: """ From cc5203df06e1991054e6601e51eef78684d66d06 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 01:34:26 +0000 Subject: [PATCH 183/204] move conversion module into dedicated utils Signed-off-by: 
varun-edachali-dbx --- src/databricks/sql/backend/sea/{ => utils}/conversion.py | 0 src/databricks/sql/result_set.py | 2 +- tests/unit/test_sea_conversion.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/databricks/sql/backend/sea/{ => utils}/conversion.py (100%) diff --git a/src/databricks/sql/backend/sea/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py similarity index 100% rename from src/databricks/sql/backend/sea/conversion.py rename to src/databricks/sql/backend/sea/utils/conversion.py diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a4814db57..d95178148 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -8,7 +8,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest -from databricks.sql.backend.sea.conversion import SqlTypeConverter +from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter try: import pyarrow diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py index 738889975..5b25123bd 100644 --- a/tests/unit/test_sea_conversion.py +++ b/tests/unit/test_sea_conversion.py @@ -9,7 +9,7 @@ import decimal from unittest.mock import Mock, patch -from databricks.sql.backend.sea.conversion import SqlType, SqlTypeConverter +from databricks.sql.backend.sea.utils.conversion import SqlType, SqlTypeConverter class TestSqlTypeConverter: From cc86832f7cf05203b6dab429de94823747bf6408 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 01:46:12 +0000 Subject: [PATCH 184/204] clean up _convert_decimal, introduce scale and precision as kwargs Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/conversion.py | 52 ++++++++++++++----- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index a3edd6dcc..4f8d6d6b5 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -14,6 +14,40 @@ logger = logging.getLogger(__name__) +def _convert_decimal( + value: str, precision: Optional[int] = None, scale: Optional[int] = None +) -> decimal.Decimal: + """ + Convert a string value to a decimal with optional precision and scale. 
+ + Args: + value: The string value to convert + precision: Optional precision (total number of significant digits) for the decimal + scale: Optional scale (number of decimal places) for the decimal + + Returns: + A decimal.Decimal object with appropriate precision and scale + """ + + # First create the decimal from the string value + result = decimal.Decimal(value) + + # Apply scale (quantize to specific number of decimal places) if specified + quantizer = None + if scale is not None: + quantizer = decimal.Decimal(f'0.{"0" * scale}') + + # Apply precision (total number of significant digits) if specified + context = None + if precision is not None: + context = decimal.Context(prec=precision) + + if quantizer is not None: + result = result.quantize(quantizer, context=context) + + return result + + class SqlType: """ SQL type constants @@ -72,13 +106,7 @@ class SqlTypeConverter: SqlType.LONG: lambda v: int(v), SqlType.FLOAT: lambda v: float(v), SqlType.DOUBLE: lambda v: float(v), - SqlType.DECIMAL: lambda v, p=None, s=None: ( - decimal.Decimal(v).quantize( - decimal.Decimal(f'0.{"0" * s}'), context=decimal.Context(prec=p) - ) - if p is not None and s is not None - else decimal.Decimal(v) - ), + SqlType.DECIMAL: _convert_decimal, # Boolean type SqlType.BOOLEAN: lambda v: v.lower() in ("true", "t", "1", "yes", "y"), # Date/Time types @@ -98,10 +126,9 @@ class SqlTypeConverter: @staticmethod def convert_value( - value: Any, + value: str, sql_type: str, - precision: Optional[int] = None, - scale: Optional[int] = None, + **kwargs, ) -> Any: """ Convert a string value to the appropriate Python type based on SQL type. @@ -109,8 +136,7 @@ def convert_value( Args: value: The string value to convert sql_type: The SQL type (e.g., 'int', 'decimal') - precision: Optional precision for decimal types - scale: Optional scale for decimal types + **kwargs: Additional keyword arguments for the conversion function Returns: The converted value in the appropriate Python type @@ -127,6 +153,8 @@ def convert_value( converter_func = SqlTypeConverter.TYPE_MAPPING[sql_type] try: if sql_type == SqlType.DECIMAL: + precision = kwargs.get("precision", None) + scale = kwargs.get("scale", None) return converter_func(value, precision, scale) else: return converter_func(value) From 3f1fd937451963c3880c2b80e5f6396be0c238a5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 01:50:50 +0000 Subject: [PATCH 185/204] use stronger typing in convert_value (object instead of Any) Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/conversion.py | 7 ++----- tests/unit/test_sea_conversion.py | 6 ------ 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py index 4f8d6d6b5..b2de97f5d 100644 --- a/src/databricks/sql/backend/sea/utils/conversion.py +++ b/src/databricks/sql/backend/sea/utils/conversion.py @@ -9,7 +9,7 @@ import decimal import logging from dateutil import parser -from typing import Any, Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional logger = logging.getLogger(__name__) @@ -129,7 +129,7 @@ def convert_value( value: str, sql_type: str, **kwargs, - ) -> Any: + ) -> object: """ Convert a string value to the appropriate Python type based on SQL type. 
@@ -142,9 +142,6 @@ def convert_value( The converted value in the appropriate Python type """ - if value is None: - return None - sql_type = sql_type.lower().strip() if sql_type not in SqlTypeConverter.TYPE_MAPPING: diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py index 5b25123bd..13970c5db 100644 --- a/tests/unit/test_sea_conversion.py +++ b/tests/unit/test_sea_conversion.py @@ -15,12 +15,6 @@ class TestSqlTypeConverter: """Test suite for the SqlTypeConverter class.""" - def test_convert_value_null(self): - """Test converting null values.""" - assert SqlTypeConverter.convert_value(None, SqlType.INT) is None - assert SqlTypeConverter.convert_value(None, SqlType.STRING) is None - assert SqlTypeConverter.convert_value(None, SqlType.BOOLEAN) is None - def test_convert_numeric_types(self): """Test converting numeric types.""" # Test integer types From 0bdf8f98677a6f253e5b3d8c0f5fc7d8c47a1d70 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 01:57:04 +0000 Subject: [PATCH 186/204] make Manifest mandatory Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 2 +- tests/unit/test_sea_result_set.py | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index d95178148..eaf19ec76 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -456,7 +456,7 @@ def __init__( execute_response: ExecuteResponse, sea_client: SeaDatabricksClient, result_data: ResultData, - manifest: Optional[ResultManifest] = None, + manifest: ResultManifest, buffer_size_bytes: int = 104857600, arraysize: int = 10000, ): diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 775b42d13..25c67df4c 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -6,11 +6,11 @@ """ import pytest -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import Mock from databricks.sql.result_set import SeaResultSet, Row from databricks.sql.utils import JsonQueue -from databricks.sql.backend.types import CommandId, CommandState, BackendType +from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.backend.sea.models.base import ResultData, ResultManifest @@ -88,6 +88,18 @@ def json_queue(self, sample_data): """Create a JsonQueue with sample data.""" return JsonQueue(sample_data) + def empty_manifest(self): + """Create an empty manifest.""" + return ResultManifest( + format="JSON_ARRAY", + schema={}, + total_row_count=0, + total_byte_count=0, + total_chunk_count=0, + truncated=False, + is_volume_operation=False, + ) + def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -97,6 +109,7 @@ def test_init_with_execute_response( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) @@ -117,6 +130,7 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) @@ -138,6 +152,7 @@ def test_close_when_already_closed_server_side( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) 
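For context, a rough sketch (not something the patch ships) of what the now-mandatory manifest argument looks like when building a SeaResultSet, mirroring the empty_manifest() helper added to the tests above; the keyword names follow the ResultManifest model as used in that fixture.

from databricks.sql.backend.sea.models.base import ResultData, ResultManifest

manifest = ResultManifest(
    format="JSON_ARRAY",
    schema={},
    total_row_count=0,
    total_byte_count=0,
    total_chunk_count=0,
    truncated=False,
    is_volume_operation=False,
)

# Every construction site, including the filtered copies built in
# filters.py, now has to pass a manifest explicitly, e.g.:
# SeaResultSet(connection=conn, execute_response=resp, sea_client=client,
#              result_data=ResultData(data=[]), manifest=manifest)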
@@ -161,6 +176,7 @@ def test_close_when_connection_closed( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) @@ -317,6 +333,7 @@ def test_fetchmany_arrow_not_implemented( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) @@ -336,6 +353,7 @@ def test_fetchall_arrow_not_implemented( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) @@ -353,6 +371,7 @@ def test_is_staging_operation( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), + manifest=self.empty_manifest(), buffer_size_bytes=1000, arraysize=100, ) From 28b4d7b5ec7a0adfee54734dabab8c21dac37a60 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 02:05:30 +0000 Subject: [PATCH 187/204] mandatory Manifest, clean up statement_id typing Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/filters.py | 4 ++++ src/databricks/sql/backend/types.py | 6 +++--- src/databricks/sql/result_set.py | 10 ++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index f3bf4669a..c50c20a43 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -73,11 +73,15 @@ def _filter_sea_result_set( from databricks.sql.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data + manifest = result_set.manifest + manifest.total_row_count = len(filtered_rows) + filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, + manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, arraysize=result_set.arraysize, ) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 93bd7d525..55808b724 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -369,7 +369,7 @@ def from_sea_statement_id(cls, statement_id: str): return cls(BackendType.SEA, statement_id) - def to_thrift_handle(self): + def to_thrift_handle(self) -> Optional[ttypes.TOperationHandle]: """ Convert this CommandId to a Thrift TOperationHandle. @@ -390,7 +390,7 @@ def to_thrift_handle(self): modifiedRowCount=self.modified_row_count, ) - def to_sea_statement_id(self): + def to_sea_statement_id(self) -> Optional[str]: """ Get the SEA statement ID string. 
@@ -401,7 +401,7 @@ def to_sea_statement_id(self): if self.backend_type != BackendType.SEA: return None - return self.guid + return str(self.guid) def to_hex_guid(self) -> str: """ diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index eaf19ec76..ff702218a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -473,10 +473,16 @@ def __init__( manifest: Manifest from SEA response """ + self.manifest = manifest + + statement_id = execute_response.command_id.to_sea_statement_id() + if statement_id is None: + raise ValueError("Command ID is not a SEA statement ID") + results_queue = SeaResultSetQueueFactory.build_queue( result_data, - manifest, - str(execute_response.command_id.to_sea_statement_id()), + self.manifest, + statement_id, description=execute_response.description, max_download_threads=sea_client.max_download_threads, sea_client=sea_client, From 245aa773adc1f31c271f1848f29804b01fe13c59 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 02:18:14 +0000 Subject: [PATCH 188/204] stronger typing for fetch*_json Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ff702218a..57f935aad 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -505,7 +505,7 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", ) - def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": + def _convert_json_to_arrow(self, rows: List[List]) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. """ @@ -519,7 +519,7 @@ def _convert_json_to_arrow(self, rows: List) -> "pyarrow.Table": names = [col[0] for col in self.description] return pyarrow.Table.from_arrays(columns, names=names) - def _convert_json_types(self, rows: List) -> List: + def _convert_json_types(self, rows: List[List]) -> List[List]: """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. @@ -554,7 +554,7 @@ def _convert_json_types(self, rows: List) -> List: return converted_rows - def _create_json_table(self, rows: List) -> List[Row]: + def _create_json_table(self, rows: List[List]) -> List[Row]: """ Convert raw data rows to Row objects with named columns based on description. Also converts string values to appropriate Python types based on column metadata. @@ -573,7 +573,7 @@ def _create_json_table(self, rows: List) -> List[Row]: return [ResultRow(*row) for row in rows] - def fetchmany_json(self, size: int) -> List: + def fetchmany_json(self, size: int) -> List[List]: """ Fetch the next set of rows as a columnar table. @@ -595,7 +595,7 @@ def fetchmany_json(self, size: int) -> List: return results - def fetchall_json(self) -> List: + def fetchall_json(self) -> List[List]: """ Fetch all remaining rows as a columnar table. 
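A small illustrative sketch (not from the patch) of the CommandId round-trip that the new statement-id checks guard against: only a SEA-backed CommandId yields a usable statement id, so SeaResultSet.__init__ now raises when to_sea_statement_id() returns None. The statement id string below is invented for the example.

from databricks.sql.backend.types import CommandId

sea_command_id = CommandId.from_sea_statement_id("stmt-0123-example")
assert sea_command_id.to_sea_statement_id() == "stmt-0123-example"

# A CommandId that is not SEA-backed (e.g. built from a Thrift handle)
# returns None from to_sea_statement_id(), which is why SeaResultSet.__init__
# raises ValueError instead of passing None along.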
From 7d21ad1107e17a1b136f210321ff484a7a0996bc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 02:31:45 +0000 Subject: [PATCH 189/204] make description non Optional, correct docstring, optimize col conversion Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 9 +++------ src/databricks/sql/utils.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 57f935aad..416e6e03a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -512,17 +512,14 @@ def _convert_json_to_arrow(self, rows: List[List]) -> "pyarrow.Table": if not rows: return pyarrow.Table.from_pydict({}) - columns = [] - num_cols = len(rows[0]) - for i in range(num_cols): - columns.append([row[i] for row in rows]) + # Transpose rows to columns efficiently using zip + columns = list(map(list, zip(*rows))) names = [col[0] for col in self.description] return pyarrow.Table.from_arrays(columns, names=names) def _convert_json_types(self, rows: List[List]) -> List[List]: """ - Convert raw data rows to Row objects with named columns based on description. - Also converts string values to appropriate Python types based on column metadata. + Convert string values to appropriate Python types based on column metadata. """ if not self.description or not rows: diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index fcf39df33..6f56be369 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -274,7 +274,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ): """ A queue-like wrapper over CloudFetch arrow batches. 
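A quick sketch (illustration only, not part of the patch) of the row-to-column transpose that the zip-based rewrite above performs; the sample data is invented.

rows = [["1", "alice"], ["2", "bob"], ["3", "carol"]]

# zip(*rows) yields one tuple per column (the i-th element of every row);
# map(list, ...) turns each tuple back into a list, giving column arrays
# suitable for pyarrow.Table.from_arrays().
columns = list(map(list, zip(*rows)))
assert columns == [["1", "2", "3"], ["alice", "bob", "carol"]]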
From 4d10dcc4f382b12a4f9c9f5e05e8e7e1831c895c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 03:16:49 +0000 Subject: [PATCH 190/204] fix type issues Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 ++++++++ src/databricks/sql/result_set.py | 6 ------ src/databricks/sql/utils.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 2ed248c3d..a0cb81710 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -511,6 +511,8 @@ def cancel_command(self, command_id: CommandId) -> None: raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ProgrammingError("Not a valid SEA command ID") request = CancelStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -534,6 +536,8 @@ def close_command(self, command_id: CommandId) -> None: raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ProgrammingError("Not a valid SEA command ID") request = CloseStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -560,6 +564,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState: raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ProgrammingError("Not a valid SEA command ID") request = GetStatementRequest(statement_id=sea_statement_id) response_data = self.http_client._make_request( @@ -595,6 +601,8 @@ def get_execution_result( raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() + if sea_statement_id is None: + raise ProgrammingError("Not a valid SEA command ID") # Create the request model request = GetStatementRequest(statement_id=sea_statement_id) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 416e6e03a..1ed0cf57c 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -522,9 +522,6 @@ def _convert_json_types(self, rows: List[List]) -> List[List]: Convert string values to appropriate Python types based on column metadata. """ - if not self.description or not rows: - return rows - # JSON + INLINE gives us string values, so we convert them to appropriate # types based on column metadata converted_rows = [] @@ -562,9 +559,6 @@ def _create_json_table(self, rows: List[List]) -> List[Row]: List of Row objects with named columns and converted values """ - if not self.description or not rows: - return rows - ResultRow = Row(*[col[0] for col in self.description]) rows = self._convert_json_types(rows) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 6f56be369..fcf39df33 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -274,7 +274,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: List[Tuple] = [], + description: Optional[List[Tuple]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. 
From 14c5625a153fae94fba98237cb24529f4a7f7e87 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 03:24:28 +0000 Subject: [PATCH 191/204] make description mandatory, not Optional Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 9 +++------ src/databricks/sql/backend/types.py | 2 +- src/databricks/sql/result_set.py | 4 ++-- src/databricks/sql/utils.py | 6 +++--- tests/unit/test_client.py | 1 + tests/unit/test_sea_backend.py | 12 ------------ 6 files changed, 10 insertions(+), 24 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a0cb81710..3d398d90a 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -290,7 +290,7 @@ def get_allowed_session_configurations() -> List[str]: def _extract_description_from_manifest( self, manifest: ResultManifest - ) -> Optional[List]: + ) -> List[Tuple]: """ Extract column description from a manifest object, in the format defined by the spec: https://peps.python.org/pep-0249/#description @@ -299,15 +299,12 @@ def _extract_description_from_manifest( manifest: The ResultManifest object containing schema information Returns: - Optional[List]: A list of column tuples or None if no columns are found + List[Tuple]: A list of column tuples """ schema_data = manifest.schema columns_data = schema_data.get("columns", []) - if not columns_data: - return None - columns = [] for col_data in columns_data: # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) @@ -323,7 +320,7 @@ def _extract_description_from_manifest( ) ) - return columns if columns else None + return columns def _results_message_to_execute_response( self, response: GetStatementResponse diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 55808b724..5411af74f 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -423,7 +423,7 @@ class ExecuteResponse: command_id: CommandId status: CommandState - description: Optional[List[Tuple]] = None + description: List[Tuple] has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 1ed0cf57c..a9ddc6581 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Tuple import logging import pandas @@ -50,7 +50,7 @@ def __init__( has_been_closed_server_side: bool = False, is_direct_results: bool = False, results_queue=None, - description=None, + description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, arrow_schema_bytes: Optional[bytes] = None, diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index fcf39df33..c2eec6a3d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -61,7 +61,7 @@ def build_queue( max_download_threads: int, ssl_options: SSLOptions, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ) -> ResultSetQueue: """ Factory method to build a result set queue. 
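For reference, a small sketch (column names and values invented for illustration) of the PEP 249 description tuples that _extract_description_from_manifest now always returns; with this commit an empty schema yields an empty list rather than None, so callers can iterate the description without a None check.

# (name, type_code, display_size, internal_size, precision, scale, null_ok)
description = [
    ("id", "INT", None, None, None, None, True),
    ("amount", "DECIMAL", None, None, 10, 2, True),
]

for name, type_code, _, _, precision, scale, null_ok in description:
    print(name, type_code, precision, scale, null_ok)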
@@ -117,7 +117,7 @@ def build_queue( sea_result_data: ResultData, manifest: Optional[ResultManifest], statement_id: str, - description: Optional[List[Tuple[Any, ...]]] = None, + description: List[Tuple] = [], schema_bytes: Optional[bytes] = None, max_download_threads: Optional[int] = None, sea_client: Optional[SeaDatabricksClient] = None, @@ -274,7 +274,7 @@ def __init__( start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, - description: Optional[List[Tuple]] = None, + description: List[Tuple] = [], ): """ A queue-like wrapper over CloudFetch arrow batches. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0eda7767c..5ffdea9f0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -100,6 +100,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): ) mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False + mock_execute_response.description = [] # Mock the backend that will be used by the real ThriftResultSet mock_backend = Mock(spec=ThriftDatabricksClient) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc6768d2b..1206457d7 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -621,18 +621,6 @@ def test_utility_methods(self, sea_client): assert description[1][1] == "INT" # type_code assert description[1][6] is False # null_ok - # Test _extract_description_from_manifest with empty columns - empty_manifest = MagicMock() - empty_manifest.schema = {"columns": []} - assert sea_client._extract_description_from_manifest(empty_manifest) is None - - # Test _extract_description_from_manifest with no columns key - no_columns_manifest = MagicMock() - no_columns_manifest.schema = {} - assert ( - sea_client._extract_description_from_manifest(no_columns_manifest) is None - ) - def test_results_message_to_execute_response_is_staging_operation(self, sea_client): """Test that is_staging_operation is correctly set from manifest.is_volume_operation.""" # Test when is_volume_operation is True From 31a0e52621d9df0a7e043e9f18cde188a9f457f0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 04:06:48 +0000 Subject: [PATCH 192/204] n_valid_rows -> num_rows Signed-off-by: varun-edachali-dbx --- src/databricks/sql/utils.py | 6 +++--- tests/unit/test_sea_queue.py | 8 ++++---- tests/unit/test_sea_result_set.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c2eec6a3d..da7cbe483 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -118,7 +118,6 @@ def build_queue( manifest: Optional[ResultManifest], statement_id: str, description: List[Tuple] = [], - schema_bytes: Optional[bytes] = None, max_download_threads: Optional[int] = None, sea_client: Optional[SeaDatabricksClient] = None, lz4_compressed: bool = False, @@ -141,6 +140,7 @@ def build_queue( ResultSetQueue: The appropriate queue for the result data """ + print(sea_result_data) if sea_result_data.data is not None: # INLINE disposition with JSON_ARRAY format return JsonQueue(sea_result_data.data) @@ -159,11 +159,11 @@ def __init__(self, data_array): """Initialize with JSON array data.""" self.data_array = data_array self.cur_row_index = 0 - self.n_valid_rows = len(data_array) + self.num_rows = len(data_array) def next_n_rows(self, num_rows): """Get the next n rows from the data array.""" - 
length = min(num_rows, self.n_valid_rows - self.cur_row_index) + length = min(num_rows, self.num_rows - self.cur_row_index) slice = self.data_array[self.cur_row_index : self.cur_row_index + length] self.cur_row_index += length return slice diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 4a4dee8f5..453b0f5bb 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -30,7 +30,7 @@ def test_init(self, sample_data): queue = JsonQueue(sample_data) assert queue.data_array == sample_data assert queue.cur_row_index == 0 - assert queue.n_valid_rows == len(sample_data) + assert queue.num_rows == len(sample_data) def test_next_n_rows_partial(self, sample_data): """Test fetching a subset of rows.""" @@ -82,7 +82,7 @@ def test_empty_data(self): assert queue.next_n_rows(10) == [] assert queue.remaining_rows() == [] assert queue.cur_row_index == 0 - assert queue.n_valid_rows == 0 + assert queue.num_rows == 0 class TestSeaResultSetQueueFactory: @@ -130,7 +130,7 @@ def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): # Verify the queue is a JsonQueue with the correct data assert isinstance(queue, JsonQueue) assert queue.data_array == data - assert queue.n_valid_rows == len(data) + assert queue.num_rows == len(data) def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): """Test building a queue with empty data.""" @@ -149,7 +149,7 @@ def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): # Verify the queue is a JsonQueue with empty data assert isinstance(queue, JsonQueue) assert queue.data_array == [] - assert queue.n_valid_rows == 0 + assert queue.num_rows == 0 def test_build_queue_with_external_links(self, mock_sea_client, mock_description): """Test building a queue with external links raises NotImplementedError.""" diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 25c67df4c..4ca231e49 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -194,7 +194,7 @@ def test_init_with_result_data(self, result_set_with_data, sample_data): # Verify the results queue was created correctly assert isinstance(result_set_with_data.results, JsonQueue) assert result_set_with_data.results.data_array == sample_data - assert result_set_with_data.results.n_valid_rows == len(sample_data) + assert result_set_with_data.results.num_rows == len(sample_data) def test_convert_json_types(self, result_set_with_data, sample_data): """Test the _convert_json_types method.""" From e86e755d9fd589dabb742a0833fef9a4edb9667a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 04:07:10 +0000 Subject: [PATCH 193/204] remove excess print statement Signed-off-by: varun-edachali-dbx --- src/databricks/sql/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index da7cbe483..165c1ed2e 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -140,7 +140,6 @@ def build_queue( ResultSetQueue: The appropriate queue for the result data """ - print(sea_result_data) if sea_result_data.data is not None: # INLINE disposition with JSON_ARRAY format return JsonQueue(sea_result_data.data) From 7035098e814b1a00ce0526c6889a6a0fbba278e6 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 04:12:01 +0000 Subject: [PATCH 194/204] remove empty bytes in SeaResultSet for arrow_schema_bytes Signed-off-by: varun-edachali-dbx --- 
src/databricks/sql/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a9ddc6581..ff0c2f806 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -502,7 +502,7 @@ def __init__( description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + arrow_schema_bytes=execute_response.arrow_schema_bytes, ) def _convert_json_to_arrow(self, rows: List[List]) -> "pyarrow.Table": From 4566cb1e04cee6c0b1b20ec2f97ff9ec897a0d77 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 04:21:34 +0000 Subject: [PATCH 195/204] move SeaResultSetQueueFactory and JsonQueue into separate SEA module Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/queue.py | 71 +++++++++++++++++++++++++ src/databricks/sql/result_set.py | 3 +- src/databricks/sql/utils.py | 63 ---------------------- tests/unit/test_sea_queue.py | 2 +- tests/unit/test_sea_result_set.py | 2 +- 5 files changed, 74 insertions(+), 67 deletions(-) create mode 100644 src/databricks/sql/backend/sea/queue.py diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py new file mode 100644 index 000000000..9fca829d1 --- /dev/null +++ b/src/databricks/sql/backend/sea/queue.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from abc import ABC +from typing import List, Optional, Tuple + +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.utils import ResultSetQueue + + +class SeaResultSetQueueFactory(ABC): + @staticmethod + def build_queue( + sea_result_data: ResultData, + manifest: Optional[ResultManifest], + statement_id: str, + description: List[Tuple] = [], + max_download_threads: Optional[int] = None, + sea_client: Optional[SeaDatabricksClient] = None, + lz4_compressed: bool = False, + ) -> ResultSetQueue: + """ + Factory method to build a result set queue for SEA backend. 
+ + Args: + sea_result_data (ResultData): Result data from SEA response + manifest (ResultManifest): Manifest from SEA response + statement_id (str): Statement ID for the query + description (List[List[Any]]): Column descriptions + schema_bytes (bytes): Arrow schema bytes + max_download_threads (int): Maximum number of download threads + ssl_options (SSLOptions): SSL options for downloads + sea_client (SeaDatabricksClient): SEA client for fetching additional links + lz4_compressed (bool): Whether the data is LZ4 compressed + + Returns: + ResultSetQueue: The appropriate queue for the result data + """ + + if sea_result_data.data is not None: + # INLINE disposition with JSON_ARRAY format + return JsonQueue(sea_result_data.data) + elif sea_result_data.external_links is not None: + # EXTERNAL_LINKS disposition + raise NotImplementedError( + "EXTERNAL_LINKS disposition is not implemented for SEA backend" + ) + return JsonQueue([]) + + +class JsonQueue(ResultSetQueue): + """Queue implementation for JSON_ARRAY format data.""" + + def __init__(self, data_array): + """Initialize with JSON array data.""" + self.data_array = data_array + self.cur_row_index = 0 + self.num_rows = len(data_array) + + def next_n_rows(self, num_rows): + """Get the next n rows from the data array.""" + length = min(num_rows, self.num_rows - self.cur_row_index) + slice = self.data_array[self.cur_row_index : self.cur_row_index + length] + self.cur_row_index += length + return slice + + def remaining_rows(self): + """Get all remaining rows from the data array.""" + slice = self.data_array[self.cur_row_index :] + self.cur_row_index += len(slice) + return slice diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index ff0c2f806..01c05ee6a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -24,9 +24,8 @@ from databricks.sql.utils import ( ColumnTable, ColumnQueue, - JsonQueue, - SeaResultSetQueueFactory, ) +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 165c1ed2e..35c7bce4d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -111,69 +111,6 @@ def build_queue( raise AssertionError("Row set type is not valid") -class SeaResultSetQueueFactory(ABC): - @staticmethod - def build_queue( - sea_result_data: ResultData, - manifest: Optional[ResultManifest], - statement_id: str, - description: List[Tuple] = [], - max_download_threads: Optional[int] = None, - sea_client: Optional[SeaDatabricksClient] = None, - lz4_compressed: bool = False, - ) -> ResultSetQueue: - """ - Factory method to build a result set queue for SEA backend. 
- - Args: - sea_result_data (ResultData): Result data from SEA response - manifest (ResultManifest): Manifest from SEA response - statement_id (str): Statement ID for the query - description (List[List[Any]]): Column descriptions - schema_bytes (bytes): Arrow schema bytes - max_download_threads (int): Maximum number of download threads - ssl_options (SSLOptions): SSL options for downloads - sea_client (SeaDatabricksClient): SEA client for fetching additional links - lz4_compressed (bool): Whether the data is LZ4 compressed - - Returns: - ResultSetQueue: The appropriate queue for the result data - """ - - if sea_result_data.data is not None: - # INLINE disposition with JSON_ARRAY format - return JsonQueue(sea_result_data.data) - elif sea_result_data.external_links is not None: - # EXTERNAL_LINKS disposition - raise NotImplementedError( - "EXTERNAL_LINKS disposition is not implemented for SEA backend" - ) - return JsonQueue([]) - - -class JsonQueue(ResultSetQueue): - """Queue implementation for JSON_ARRAY format data.""" - - def __init__(self, data_array): - """Initialize with JSON array data.""" - self.data_array = data_array - self.cur_row_index = 0 - self.num_rows = len(data_array) - - def next_n_rows(self, num_rows): - """Get the next n rows from the data array.""" - length = min(num_rows, self.num_rows - self.cur_row_index) - slice = self.data_array[self.cur_row_index : self.cur_row_index + length] - self.cur_row_index += length - return slice - - def remaining_rows(self): - """Get all remaining rows from the data array.""" - slice = self.data_array[self.cur_row_index :] - self.cur_row_index += len(slice) - return slice - - class ColumnTable: def __init__(self, column_table, column_names): self.column_table = column_table diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 453b0f5bb..f06cd6ed5 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -7,7 +7,7 @@ import pytest from unittest.mock import Mock, MagicMock, patch -from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.sea.models.base import ResultData, ResultManifest diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 4ca231e49..89ccf9b6c 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -9,7 +9,7 @@ from unittest.mock import Mock from databricks.sql.result_set import SeaResultSet, Row -from databricks.sql.utils import JsonQueue +from databricks.sql.backend.sea.queue import JsonQueue from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.backend.sea.models.base import ResultData, ResultManifest From 72a5cd36016ead3869f5222e9fd3ab941ef0d0c1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 05:01:46 +0000 Subject: [PATCH 196/204] move sea result set into backend/sea package Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 +- src/databricks/sql/backend/sea/result_set.py | 259 ++++++++++++++++++ .../sql/backend/sea/utils/filters.py | 4 +- src/databricks/sql/result_set.py | 247 ----------------- tests/unit/test_filters.py | 4 +- tests/unit/test_sea_backend.py | 2 +- tests/unit/test_sea_result_set.py | 11 +- 7 files changed, 271 insertions(+), 260 deletions(-) create mode 100644 src/databricks/sql/backend/sea/result_set.py diff --git a/src/databricks/sql/backend/sea/backend.py 
b/src/databricks/sql/backend/sea/backend.py index 3d398d90a..247c2ed91 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -613,7 +613,7 @@ def get_execution_result( response = GetStatementResponse.from_dict(response_data) # Create and return a SeaResultSet - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet execute_response = self._results_message_to_execute_response(response) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py new file mode 100644 index 000000000..8ce067850 --- /dev/null +++ b/src/databricks/sql/backend/sea/result_set.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from typing import List, Optional, TYPE_CHECKING + +import logging + +from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter + +try: + import pyarrow +except ImportError: + pyarrow = None + +if TYPE_CHECKING: + from databricks.sql.client import Connection +from databricks.sql.types import Row +from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.result_set import ResultSet + +logger = logging.getLogger(__name__) + + +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection: Connection, + execute_response: ExecuteResponse, + sea_client: SeaDatabricksClient, + result_data: ResultData, + manifest: ResultManifest, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + ): + """ + Initialize a SeaResultSet with the response from a SEA query execution. 
+ + Args: + connection: The parent connection + execute_response: Response from the execute command + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + result_data: Result data from SEA response + manifest: Manifest from SEA response + """ + + self.manifest = manifest + + statement_id = execute_response.command_id.to_sea_statement_id() + if statement_id is None: + raise ValueError("Command ID is not a SEA statement ID") + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + description=execute_response.description, + max_download_threads=sea_client.max_download_threads, + sea_client=sea_client, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + results_queue=results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + lz4_compressed=execute_response.lz4_compressed, + arrow_schema_bytes=execute_response.arrow_schema_bytes, + ) + + def _convert_json_types(self, row: List) -> List: + """ + Convert string values to appropriate Python types based on column metadata. + """ + + # JSON + INLINE gives us string values, so we convert them to appropriate + # types based on column metadata + converted_row = [] + + for i, value in enumerate(row): + column_type = self.description[i][1] + precision = self.description[i][4] + scale = self.description[i][5] + + try: + converted_value = SqlTypeConverter.convert_value( + value, column_type, precision=precision, scale=scale + ) + converted_row.append(converted_value) + except Exception as e: + logger.warning( + f"Error converting value '{value}' to {column_type}: {e}" + ) + converted_row.append(value) + + return converted_row + + def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table": + """ + Convert raw data rows to Arrow table. + """ + if not rows: + return pyarrow.Table.from_pydict({}) + + # create a generator for row conversion + converted_rows_iter = (self._convert_json_types(row) for row in rows) + cols = list(map(list, zip(*converted_rows_iter))) + + names = [col[0] for col in self.description] + return pyarrow.Table.from_arrays(cols, names=names) + + def _create_json_table(self, rows: List[List]) -> List[Row]: + """ + Convert raw data rows to Row objects with named columns based on description. + Also converts string values to appropriate Python types based on column metadata. + + Args: + rows: List of raw data rows + Returns: + List of Row objects with named columns and converted values + """ + + ResultRow = Row(*[col[0] for col in self.description]) + return [ResultRow(*self._convert_json_types(row)) for row in rows] + + def fetchmany_json(self, size: int) -> List[List]: + """ + Fetch the next set of rows as a columnar table. 
+ + Args: + size: Number of rows to fetch + + Returns: + Columnar table containing the fetched rows + + Raises: + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + results = self.results.next_n_rows(size) + self._next_row_index += len(results) + + return results + + def fetchall_json(self) -> List[List]: + """ + Fetch all remaining rows as a columnar table. + + Returns: + Columnar table containing all remaining rows + """ + + results = self.results.remaining_rows() + self._next_row_index += len(results) + + return results + + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + """ + Fetch the next set of rows as an Arrow table. + + Args: + size: Number of rows to fetch + + Returns: + PyArrow Table containing the fetched rows + + Raises: + ImportError: If PyArrow is not installed + ValueError: If size is negative + """ + + if size < 0: + raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchmany_arrow only supported for JSON data") + + results = self._convert_json_to_arrow_table(self.results.next_n_rows(size)) + self._next_row_index += results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": + """ + Fetch all remaining rows as an Arrow table. + """ + + if not isinstance(self.results, JsonQueue): + raise NotImplementedError("fetchall_arrow only supported for JSON data") + + results = self._convert_json_to_arrow_table(self.results.remaining_rows()) + self._next_row_index += results.num_rows + + return results + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + + Returns: + A single Row object or None if no more rows are available + """ + + if isinstance(self.results, JsonQueue): + res = self._create_json_table(self.fetchmany_json(1)) + else: + raise NotImplementedError("fetchone only supported for JSON data") + + return res[0] if res else None + + def fetchmany(self, size: int) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + Args: + size: Number of rows to fetch (defaults to arraysize if None) + + Returns: + List of Row objects + + Raises: + ValueError: If size is negative + """ + + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchmany_json(size)) + else: + raise NotImplementedError("fetchmany only supported for JSON data") + + def fetchall(self) -> List[Row]: + """ + Fetch all remaining rows of a query result, returning them as a list of rows. 
+ + Returns: + List of Row objects containing all remaining rows + """ + + if isinstance(self.results, JsonQueue): + return self._create_json_table(self.fetchall_json()) + else: + raise NotImplementedError("fetchall only supported for JSON data") diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index c50c20a43..ef6c91d7d 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -17,7 +17,7 @@ ) if TYPE_CHECKING: - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.types import ExecuteResponse @@ -70,7 +70,7 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) from databricks.sql.backend.sea.backend import SeaDatabricksClient - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data manifest = result_set.manifest diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 01c05ee6a..8934d0d56 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -6,10 +6,6 @@ import logging import pandas -from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.sea.models.base import ResultData, ResultManifest -from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter - try: import pyarrow except ImportError: @@ -25,7 +21,6 @@ ColumnTable, ColumnQueue, ) -from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -444,245 +439,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - - -class SeaResultSet(ResultSet): - """ResultSet implementation for SEA backend.""" - - def __init__( - self, - connection: Connection, - execute_response: ExecuteResponse, - sea_client: SeaDatabricksClient, - result_data: ResultData, - manifest: ResultManifest, - buffer_size_bytes: int = 104857600, - arraysize: int = 10000, - ): - """ - Initialize a SeaResultSet with the response from a SEA query execution. 
- - Args: - connection: The parent connection - execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access - buffer_size_bytes: Buffer size for fetching results - arraysize: Default number of rows to fetch - result_data: Result data from SEA response - manifest: Manifest from SEA response - """ - - self.manifest = manifest - - statement_id = execute_response.command_id.to_sea_statement_id() - if statement_id is None: - raise ValueError("Command ID is not a SEA statement ID") - - results_queue = SeaResultSetQueueFactory.build_queue( - result_data, - self.manifest, - statement_id, - description=execute_response.description, - max_download_threads=sea_client.max_download_threads, - sea_client=sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) - - # Call parent constructor with common attributes - super().__init__( - connection=connection, - backend=sea_client, - arraysize=arraysize, - buffer_size_bytes=buffer_size_bytes, - command_id=execute_response.command_id, - status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, - results_queue=results_queue, - description=execute_response.description, - is_staging_operation=execute_response.is_staging_operation, - lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, - ) - - def _convert_json_to_arrow(self, rows: List[List]) -> "pyarrow.Table": - """ - Convert raw data rows to Arrow table. - """ - if not rows: - return pyarrow.Table.from_pydict({}) - - # Transpose rows to columns efficiently using zip - columns = list(map(list, zip(*rows))) - names = [col[0] for col in self.description] - return pyarrow.Table.from_arrays(columns, names=names) - - def _convert_json_types(self, rows: List[List]) -> List[List]: - """ - Convert string values to appropriate Python types based on column metadata. - """ - - # JSON + INLINE gives us string values, so we convert them to appropriate - # types based on column metadata - converted_rows = [] - for row in rows: - converted_row = [] - - for i, value in enumerate(row): - column_type = self.description[i][1] - precision = self.description[i][4] - scale = self.description[i][5] - - try: - converted_value = SqlTypeConverter.convert_value( - value, column_type, precision=precision, scale=scale - ) - converted_row.append(converted_value) - except Exception as e: - logger.warning( - f"Error converting value '{value}' to {column_type}: {e}" - ) - converted_row.append(value) - - converted_rows.append(converted_row) - - return converted_rows - - def _create_json_table(self, rows: List[List]) -> List[Row]: - """ - Convert raw data rows to Row objects with named columns based on description. - Also converts string values to appropriate Python types based on column metadata. - - Args: - rows: List of raw data rows - Returns: - List of Row objects with named columns and converted values - """ - - ResultRow = Row(*[col[0] for col in self.description]) - rows = self._convert_json_types(rows) - - return [ResultRow(*row) for row in rows] - - def fetchmany_json(self, size: int) -> List[List]: - """ - Fetch the next set of rows as a columnar table. 
- - Args: - size: Number of rows to fetch - - Returns: - Columnar table containing the fetched rows - - Raises: - ValueError: If size is negative - """ - - if size < 0: - raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - - results = self.results.next_n_rows(size) - self._next_row_index += len(results) - - return results - - def fetchall_json(self) -> List[List]: - """ - Fetch all remaining rows as a columnar table. - - Returns: - Columnar table containing all remaining rows - """ - - results = self.results.remaining_rows() - self._next_row_index += len(results) - - return results - - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": - """ - Fetch the next set of rows as an Arrow table. - - Args: - size: Number of rows to fetch - - Returns: - PyArrow Table containing the fetched rows - - Raises: - ImportError: If PyArrow is not installed - ValueError: If size is negative - """ - - if size < 0: - raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") - - if not isinstance(self.results, JsonQueue): - raise NotImplementedError("fetchmany_arrow only supported for JSON data") - - rows = self._convert_json_types(self.results.next_n_rows(size)) - results = self._convert_json_to_arrow(rows) - self._next_row_index += results.num_rows - - return results - - def fetchall_arrow(self) -> "pyarrow.Table": - """ - Fetch all remaining rows as an Arrow table. - """ - - if not isinstance(self.results, JsonQueue): - raise NotImplementedError("fetchall_arrow only supported for JSON data") - - rows = self._convert_json_types(self.results.remaining_rows()) - results = self._convert_json_to_arrow(rows) - self._next_row_index += results.num_rows - - return results - - def fetchone(self) -> Optional[Row]: - """ - Fetch the next row of a query result set, returning a single sequence, - or None when no more data is available. - - Returns: - A single Row object or None if no more rows are available - """ - - if isinstance(self.results, JsonQueue): - res = self._create_json_table(self.fetchmany_json(1)) - else: - raise NotImplementedError("fetchone only supported for JSON data") - - return res[0] if res else None - - def fetchmany(self, size: int) -> List[Row]: - """ - Fetch the next set of rows of a query result, returning a list of rows. - - Args: - size: Number of rows to fetch (defaults to arraysize if None) - - Returns: - List of Row objects - - Raises: - ValueError: If size is negative - """ - - if isinstance(self.results, JsonQueue): - return self._create_json_table(self.fetchmany_json(size)) - else: - raise NotImplementedError("fetchmany only supported for JSON data") - - def fetchall(self) -> List[Row]: - """ - Fetch all remaining rows of a query result, returning them as a list of rows. 
- - Returns: - List of Row objects containing all remaining rows - """ - - if isinstance(self.results, JsonQueue): - return self._create_json_table(self.fetchall_json()) - else: - raise NotImplementedError("fetchall only supported for JSON data") diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 975376e13..13dfac006 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -77,7 +77,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance @@ -104,7 +104,7 @@ def test_filter_by_column_values(self): "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True ): with patch( - "databricks.sql.result_set.SeaResultSet" + "databricks.sql.backend.sea.result_set.SeaResultSet" ) as mock_sea_result_set_class: mock_instance = MagicMock() mock_sea_result_set_class.return_value = mock_instance diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 1206457d7..23392e216 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -743,7 +743,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): def test_get_tables(self, sea_client, sea_session_id, mock_cursor): """Test the get_tables method with various parameter combinations.""" # Mock the execute_command method - from databricks.sql.result_set import SeaResultSet + from databricks.sql.backend.sea.result_set import SeaResultSet mock_result_set = Mock(spec=SeaResultSet) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 89ccf9b6c..677410708 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -8,7 +8,7 @@ import pytest from unittest.mock import Mock -from databricks.sql.result_set import SeaResultSet, Row +from databricks.sql.backend.sea.result_set import SeaResultSet, Row from databricks.sql.backend.sea.queue import JsonQueue from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.backend.sea.models.base import ResultData, ResultManifest @@ -199,13 +199,12 @@ def test_init_with_result_data(self, result_set_with_data, sample_data): def test_convert_json_types(self, result_set_with_data, sample_data): """Test the _convert_json_types method.""" # Call _convert_json_types - converted_rows = result_set_with_data._convert_json_types(sample_data) + converted_row = result_set_with_data._convert_json_types(sample_data[0]) # Verify the conversion - assert len(converted_rows) == len(sample_data) - assert converted_rows[0][0] == "value1" # string stays as string - assert converted_rows[0][1] == 1 # "1" converted to int - assert converted_rows[0][2] is True # "true" converted to boolean + assert converted_row[0] == "value1" # string stays as string + assert converted_row[1] == 1 # "1" converted to int + assert converted_row[2] is True # "true" converted to boolean def test_create_json_table(self, result_set_with_data, sample_data): """Test the _create_json_table method.""" From 511c4493e5b0e2e26b476db936f48b814aaafc3d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 05:03:28 +0000 Subject: [PATCH 197/204] improve docstrings Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 10 ++++++++-- 1 file 
changed, 8 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 8ce067850..8fb46bc61 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -83,7 +83,7 @@ def __init__( def _convert_json_types(self, row: List) -> List: """ - Convert string values to appropriate Python types based on column metadata. + Convert string values in the row to appropriate Python types based on column metadata. """ # JSON + INLINE gives us string values, so we convert them to appropriate @@ -111,7 +111,14 @@ def _convert_json_types(self, row: List) -> List: def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. + + Args: + rows: List of raw data rows + + Returns: + PyArrow Table containing the converted values """ + if not rows: return pyarrow.Table.from_pydict({}) @@ -125,7 +132,6 @@ def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table": def _create_json_table(self, rows: List[List]) -> List[Row]: """ Convert raw data rows to Row objects with named columns based on description. - Also converts string values to appropriate Python types based on column metadata. Args: rows: List of raw data rows From c21ff5e1e3619e075e6480caa1ba195b59081f74 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 05:15:07 +0000 Subject: [PATCH 198/204] correct docstrings, ProgrammingError -> ValueError Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 18 +++++++++--------- src/databricks/sql/backend/sea/queue.py | 2 -- src/databricks/sql/backend/sea/result_set.py | 1 + tests/unit/test_sea_backend.py | 10 +++++----- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 247c2ed91..814859a31 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -251,7 +251,7 @@ def close_session(self, session_id: SessionId) -> None: logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) if session_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA session ID") + raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() request_data = DeleteSessionRequest( @@ -426,7 +426,7 @@ def execute_command( """ if session_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA session ID") + raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -505,11 +505,11 @@ def cancel_command(self, command_id: CommandId) -> None: """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() if sea_statement_id is None: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") request = CancelStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -530,11 +530,11 @@ def close_command(self, command_id: CommandId) -> None: """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() if sea_statement_id is None: - raise ProgrammingError("Not a 
valid SEA command ID") + raise ValueError("Not a valid SEA command ID") request = CloseStatementRequest(statement_id=sea_statement_id) self.http_client._make_request( @@ -562,7 +562,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: sea_statement_id = command_id.to_sea_statement_id() if sea_statement_id is None: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") request = GetStatementRequest(statement_id=sea_statement_id) response_data = self.http_client._make_request( @@ -595,11 +595,11 @@ def get_execution_result( """ if command_id.backend_type != BackendType.SEA: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() if sea_statement_id is None: - raise ProgrammingError("Not a valid SEA command ID") + raise ValueError("Not a valid SEA command ID") # Create the request model request = GetStatementRequest(statement_id=sea_statement_id) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 9fca829d1..88be098e6 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -27,9 +27,7 @@ def build_queue( manifest (ResultManifest): Manifest from SEA response statement_id (str): Statement ID for the query description (List[List[Any]]): Column descriptions - schema_bytes (bytes): Arrow schema bytes max_download_threads (int): Maximum number of download threads - ssl_options (SSLOptions): SSL options for downloads sea_client (SeaDatabricksClient): SEA client for fetching additional links lz4_compressed (bool): Whether the data is LZ4 compressed diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 8fb46bc61..844d8eb73 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from databricks.sql.client import Connection +from databricks.sql.exc import ProgrammingError from databricks.sql.types import Row from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import ExecuteResponse diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 23392e216..7eae8e5a8 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -196,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i ) # Test close_session with invalid ID type - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) @@ -245,7 +245,7 @@ def test_command_execution_sync( assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: mock_thrift_handle = MagicMock() mock_thrift_handle.sessionId.guid = b"guid" mock_thrift_handle.sessionId.secret = b"secret" @@ -449,7 +449,7 @@ def test_command_management( ) # Test cancel_command with invalid ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.cancel_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -463,7 +463,7 @@ def test_command_management( ) # Test close_command with invalid ID - with 
pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.close_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -522,7 +522,7 @@ def test_command_management( assert result.status == CommandState.SUCCEEDED # Test get_execution_result with invalid ID - with pytest.raises(ProgrammingError) as excinfo: + with pytest.raises(ValueError) as excinfo: sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) From 72100b9504d7e81dedb635ad153f863be0aabc50 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 08:42:57 +0000 Subject: [PATCH 199/204] let type of rows be List[List[str]] for clarity Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 844d8eb73..0af37045e 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -130,7 +130,7 @@ def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table": names = [col[0] for col in self.description] return pyarrow.Table.from_arrays(cols, names=names) - def _create_json_table(self, rows: List[List]) -> List[Row]: + def _create_json_table(self, rows: List[List[str]]) -> List[Row]: """ Convert raw data rows to Row objects with named columns based on description. @@ -143,7 +143,7 @@ def _create_json_table(self, rows: List[List]) -> List[Row]: ResultRow = Row(*[col[0] for col in self.description]) return [ResultRow(*self._convert_json_types(row)) for row in rows] - def fetchmany_json(self, size: int) -> List[List]: + def fetchmany_json(self, size: int) -> List[List[str]]: """ Fetch the next set of rows as a columnar table. @@ -165,7 +165,7 @@ def fetchmany_json(self, size: int) -> List[List]: return results - def fetchall_json(self) -> List[List]: + def fetchall_json(self) -> List[List[str]]: """ Fetch all remaining rows as a columnar table.
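Patches 197 through 199 tighten the typing of the SEA result set's JSON helpers: inline JSON_ARRAY results arrive as List[List[str]], and each string value is converted to a Python value using the column metadata in the cursor description. The standalone sketch below illustrates only that conversion idea; the CONVERTERS table and convert_row helper are invented for illustration and are not the connector's implementation, though the expected outputs match the assertions in test_convert_json_types above.

# Illustrative sketch: convert JSON_ARRAY string values to Python types
# based on column descriptions of the form (name, type_name, ...).
from typing import Any, Callable, Dict, List

CONVERTERS: Dict[str, Callable[[str], Any]] = {
    "string": str,
    "int": int,
    "boolean": lambda v: v.lower() == "true",
}

def convert_row(row: List[str], description: List[tuple]) -> List[Any]:
    """Convert one row of string values using each column's declared type name."""
    converted: List[Any] = []
    for value, column in zip(row, description):
        converter = CONVERTERS.get(column[1], str)
        converted.append(None if value is None else converter(value))
    return converted

description = [
    ("col1", "string", None, None, None, None, None),
    ("col2", "int", None, None, None, None, None),
    ("col3", "boolean", None, None, None, None, None),
]
print(convert_row(["value1", "1", "true"], description))  # ['value1', 1, True]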
From aab33a1c904a10e37be66b097db4b4fe45d721d4 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 06:51:27 +0000 Subject: [PATCH 200/204] select Queue based on format in manifest Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/queue.py | 10 ++++--- tests/unit/test_sea_queue.py | 16 ++++++++-- tests/unit/test_sea_result_set.py | 39 ++++++++++++------------- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 88be098e6..e5e2e504f 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -5,6 +5,8 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.utils.constants import ResultFormat +from databricks.sql.exc import ProgrammingError from databricks.sql.utils import ResultSetQueue @@ -35,15 +37,15 @@ def build_queue( ResultSetQueue: The appropriate queue for the result data """ - if sea_result_data.data is not None: + if manifest.format == ResultFormat.JSON_ARRAY.value: # INLINE disposition with JSON_ARRAY format return JsonQueue(sea_result_data.data) - elif sea_result_data.external_links is not None: + elif manifest.format == ResultFormat.ARROW_STREAM.value: # EXTERNAL_LINKS disposition raise NotImplementedError( "EXTERNAL_LINKS disposition is not implemented for SEA backend" ) - return JsonQueue([]) + raise ProgrammingError("Invalid result format") class JsonQueue(ResultSetQueue): @@ -53,7 +55,7 @@ def __init__(self, data_array): """Initialize with JSON array data.""" self.data_array = data_array self.cur_row_index = 0 - self.num_rows = len(data_array) + self.num_rows = len(data_array) if data_array else 0 def next_n_rows(self, num_rows): """Get the next n rows from the data array.""" diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index f06cd6ed5..93d3dc4d7 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -9,6 +9,7 @@ from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.sea.models.base import ResultData, ResultManifest +from databricks.sql.backend.sea.utils.constants import ResultFormat class TestJsonQueue: @@ -104,6 +105,15 @@ def mock_description(self): ("col3", "boolean", None, None, None, None, None), ] + def _create_empty_manifest(self, format: ResultFormat): + return ResultManifest( + format=format.value, + schema={}, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, + ) + def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): """Test building a queue with inline JSON data.""" # Create sample data for inline JSON result @@ -116,7 +126,7 @@ def test_build_queue_with_inline_data(self, mock_sea_client, mock_description): result_data = ResultData(data=data, external_links=None, row_count=len(data)) # Create a manifest (not used for inline data) - manifest = None + manifest = self._create_empty_manifest(ResultFormat.JSON_ARRAY) # Build the queue queue = SeaResultSetQueueFactory.build_queue( @@ -140,7 +150,7 @@ def test_build_queue_with_empty_data(self, mock_sea_client, mock_description): # Build the queue queue = SeaResultSetQueueFactory.build_queue( result_data, - None, + self._create_empty_manifest(ResultFormat.JSON_ARRAY), "test-statement-123", description=mock_description, sea_client=mock_sea_client, @@ -165,7 +175,7 @@ def 
test_build_queue_with_external_links(self, mock_sea_client, mock_description ): SeaResultSetQueueFactory.build_queue( result_data, - None, + self._create_empty_manifest(ResultFormat.ARROW_STREAM), "test-statement-123", description=mock_description, sea_client=mock_sea_client, diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 677410708..544edaf96 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -10,6 +10,7 @@ from databricks.sql.backend.sea.result_set import SeaResultSet, Row from databricks.sql.backend.sea.queue import JsonQueue +from databricks.sql.backend.sea.utils.constants import ResultFormat from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.backend.sea.models.base import ResultData, ResultManifest @@ -59,6 +60,16 @@ def sample_data(self): ["value5", "5", "true"], ] + def _create_empty_manifest(self, format: ResultFormat): + """Create an empty manifest.""" + return ResultManifest( + format=format.value, + schema={}, + total_row_count=-1, + total_byte_count=-1, + total_chunk_count=-1, + ) + @pytest.fixture def result_set_with_data( self, mock_connection, mock_sea_client, execute_response, sample_data @@ -75,7 +86,7 @@ def result_set_with_data( execute_response=execute_response, sea_client=mock_sea_client, result_data=result_data, - manifest=None, + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -88,18 +99,6 @@ def json_queue(self, sample_data): """Create a JsonQueue with sample data.""" return JsonQueue(sample_data) - def empty_manifest(self): - """Create an empty manifest.""" - return ResultManifest( - format="JSON_ARRAY", - schema={}, - total_row_count=0, - total_byte_count=0, - total_chunk_count=0, - truncated=False, - is_volume_operation=False, - ) - def test_init_with_execute_response( self, mock_connection, mock_sea_client, execute_response ): @@ -109,7 +108,7 @@ def test_init_with_execute_response( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -130,7 +129,7 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -152,7 +151,7 @@ def test_close_when_already_closed_server_side( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -176,7 +175,7 @@ def test_close_when_connection_closed( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) @@ -332,7 +331,7 @@ def test_fetchmany_arrow_not_implemented( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, arraysize=100, ) @@ -352,7 +351,7 @@ def 
test_fetchall_arrow_not_implemented( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, arraysize=100, ) @@ -370,7 +369,7 @@ def test_is_staging_operation( execute_response=execute_response, sea_client=mock_sea_client, result_data=ResultData(data=[]), - manifest=self.empty_manifest(), + manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, arraysize=100, ) From 9a6db30aaac65b16d67de943d9dd1e510a0990bb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 06:52:23 +0000 Subject: [PATCH 201/204] make manifest mandatory Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index e5e2e504f..ed8b9dcc2 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -14,7 +14,7 @@ class SeaResultSetQueueFactory(ABC): @staticmethod def build_queue( sea_result_data: ResultData, - manifest: Optional[ResultManifest], + manifest: ResultManifest, statement_id: str, description: List[Tuple] = [], max_download_threads: Optional[int] = None, From 0135d33c264adbae3124e4adf18b42468a1f35cd Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 08:16:43 +0000 Subject: [PATCH 202/204] stronger type checking in JSON helper functions in Sea Result Set Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 0af37045e..302af5e3a 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING import logging @@ -82,7 +82,7 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes, ) - def _convert_json_types(self, row: List) -> List: + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert string values in the row to appropriate Python types based on column metadata. """ @@ -109,7 +109,7 @@ def _convert_json_types(self, row: List) -> List: return converted_row - def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table": + def _convert_json_to_arrow_table(self, rows: List[List[str]]) -> "pyarrow.Table": """ Convert raw data rows to Arrow table. 
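Patches 200 through 202 change how the result queue is selected: the factory now dispatches on the manifest's result format (and the manifest becomes a required argument) instead of inferring the disposition from whichever result fields happen to be populated. Below is a minimal sketch of that dispatch pattern, assuming simplified ResultFormat constants and a stand-in JsonQueue rather than the connector's real classes; the connector raises its own ProgrammingError where ValueError appears here.

# Sketch of format-based queue selection; names are stand-ins for illustration.
from enum import Enum
from typing import List, Optional

class ResultFormat(Enum):
    JSON_ARRAY = "JSON_ARRAY"
    ARROW_STREAM = "ARROW_STREAM"

class JsonQueue:
    """Stand-in for the connector's JSON_ARRAY queue."""
    def __init__(self, data: Optional[List[List[str]]]):
        self.data_array = data or []

def build_queue(manifest_format: str, data: Optional[List[List[str]]]) -> JsonQueue:
    # Dispatch on the manifest format, not on which result fields are set.
    if manifest_format == ResultFormat.JSON_ARRAY.value:
        return JsonQueue(data)  # INLINE disposition with JSON_ARRAY rows
    if manifest_format == ResultFormat.ARROW_STREAM.value:
        raise NotImplementedError("EXTERNAL_LINKS disposition is not implemented")
    raise ValueError("Invalid result format")

queue = build_queue("JSON_ARRAY", [["a", "1"], ["b", "2"]])
print(queue.data_array)  # [['a', '1'], ['b', '2']]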
From cc9db8bee2a54d06760c4736a0cb3eff8e961f5a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 14:47:47 +0530 Subject: [PATCH 203/204] assign empty array to data array if None Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/queue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index ed8b9dcc2..5f92698cb 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -53,9 +53,9 @@ class JsonQueue(ResultSetQueue): def __init__(self, data_array): """Initialize with JSON array data.""" - self.data_array = data_array + self.data_array = data_array or [] self.cur_row_index = 0 - self.num_rows = len(data_array) if data_array else 0 + self.num_rows = len(self.data_array) def next_n_rows(self, num_rows): """Get the next n rows from the data array.""" From bb135fc9b9523c7be844cd659712db655d6a411a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 15:05:25 +0530 Subject: [PATCH 204/204] stronger typing in JsonQueue Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/queue.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 5f92698cb..73f47ea96 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -51,20 +51,20 @@ def build_queue( class JsonQueue(ResultSetQueue): """Queue implementation for JSON_ARRAY format data.""" - def __init__(self, data_array): + def __init__(self, data_array: Optional[List[List[str]]]): """Initialize with JSON array data.""" self.data_array = data_array or [] self.cur_row_index = 0 self.num_rows = len(self.data_array) - def next_n_rows(self, num_rows): + def next_n_rows(self, num_rows: int) -> List[List[str]]: """Get the next n rows from the data array.""" length = min(num_rows, self.num_rows - self.cur_row_index) slice = self.data_array[self.cur_row_index : self.cur_row_index + length] self.cur_row_index += length return slice - def remaining_rows(self): + def remaining_rows(self) -> List[List[str]]: """Get all remaining rows from the data array.""" slice = self.data_array[self.cur_row_index :] self.cur_row_index += len(slice)
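Taken together, patches 203 and 204 leave JsonQueue tolerant of a missing data array and give its pagination helpers explicit types. The class below is reassembled from the diffs in this series, so treat it as illustrative of the final state rather than the released source; a short usage example of the cursor-style pagination follows.

from typing import List, Optional

class JsonQueue:
    """Queue for JSON_ARRAY format data, as assembled from this patch series."""

    def __init__(self, data_array: Optional[List[List[str]]]):
        self.data_array = data_array or []  # None becomes an empty array (patch 203)
        self.cur_row_index = 0
        self.num_rows = len(self.data_array)

    def next_n_rows(self, num_rows: int) -> List[List[str]]:
        """Return the next num_rows rows and advance the cursor."""
        length = min(num_rows, self.num_rows - self.cur_row_index)
        rows = self.data_array[self.cur_row_index : self.cur_row_index + length]
        self.cur_row_index += length
        return rows

    def remaining_rows(self) -> List[List[str]]:
        """Return all rows after the cursor and advance it to the end."""
        rows = self.data_array[self.cur_row_index :]
        self.cur_row_index += len(rows)
        return rows

queue = JsonQueue([["value1", "1", "true"], ["value2", "2", "false"], ["value3", "3", "true"]])
print(queue.next_n_rows(2))    # first two rows
print(queue.remaining_rows())  # the remaining row
print(JsonQueue(None).remaining_rows())  # [] -- a None data array is tolerated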