From 65a75f44e37c49359edf0e32d4549e02d4a31879 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 10 Jun 2025 17:29:56 +0530 Subject: [PATCH 01/86] added functionality for export of failure logs Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 76 ++++++++++++---- src/databricks/sql/exc.py | 19 +++- .../sql/telemetry/telemetry_client.py | 64 +++++++++---- src/databricks/sql/thrift_backend.py | 89 +++++++++++++------ 4 files changed, 184 insertions(+), 64 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3a682984e..f0f1f5b8a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -421,7 +421,10 @@ def cursor( Will throw an Error if the connection has been closed. """ if not self.open: - raise Error("Cannot create cursor from closed connection") + raise Error( + "Cannot create cursor from closed connection", + connection_uuid=self.get_session_id_hex(), + ) cursor = Cursor( self, @@ -471,7 +474,10 @@ def commit(self): pass def rollback(self): - raise NotSupportedError("Transactions are not supported on Databricks") + raise NotSupportedError( + "Transactions are not supported on Databricks", + connection_uuid=self.get_session_id_hex(), + ) class Cursor: @@ -523,7 +529,10 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def _determine_parameter_approach( self, params: Optional[TParameterCollection] @@ -660,7 +669,10 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error("Attempting operation on closed cursor") + raise Error( + "Attempting operation on closed cursor", + connection_uuid=self.connection.get_session_id_hex(), + ) def _handle_staging_operation( self, staging_allowed_local_path: Union[None, str, List[str]] @@ -678,7 +690,8 @@ def _handle_staging_operation( _staging_allowed_local_paths = staging_allowed_local_path else: raise Error( - "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands" + "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + connection_uuid=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -707,7 +720,8 @@ def _handle_staging_operation( continue if not allow_operation: raise Error( - "Local file operations are restricted to paths within the configured staging_allowed_local_path" + "Local file operations are restricted to paths within the configured staging_allowed_local_path", + connection_uuid=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -737,7 +751,8 @@ def _handle_staging_operation( else: raise Error( f"Operation {row.operation} is not supported. " - + "Supported operations are GET, PUT, and REMOVE" + + "Supported operations are GET, PUT, and REMOVE", + connection_uuid=self.connection.get_session_id_hex(), ) def _handle_staging_put( @@ -749,7 +764,10 @@ def _handle_staging_put( """ if local_file is None: - raise Error("Cannot perform PUT without specifying a local_file") + raise Error( + "Cannot perform PUT without specifying a local_file", + connection_uuid=self.connection.get_session_id_hex(), + ) with open(local_file, "rb") as fh: r = requests.put(url=presigned_url, data=fh, headers=headers) @@ -766,7 +784,8 @@ def _handle_staging_put( if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -784,7 +803,10 @@ def _handle_staging_get( """ if local_file is None: - raise Error("Cannot perform GET without specifying a local_file") + raise Error( + "Cannot perform GET without specifying a local_file", + connection_uuid=self.connection.get_session_id_hex(), + ) r = requests.get(url=presigned_url, headers=headers) @@ -792,7 +814,8 @@ def _handle_staging_get( # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: @@ -807,7 +830,8 @@ def _handle_staging_remove( if not r.ok: raise Error( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + connection_uuid=self.connection.get_session_id_hex(), ) def execute( @@ -1006,7 +1030,8 @@ def get_async_execution_result(self): return self else: raise Error( - f"get_execution_result failed with Operation status {operation_state}" + f"get_execution_result failed with Operation status {operation_state}", + connection_uuid=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -1156,7 +1181,10 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchone(self) -> Optional[Row]: """ @@ -1170,7 +1198,10 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchmany(self, size: int) -> List[Row]: """ @@ -1192,21 +1223,30 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error("There is no active result set") + raise Error( + "There is no active result set", + connection_uuid=self.connection.get_session_id_hex(), + ) def cancel(self) -> None: """ diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3b27283a4..92577d548 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,8 +1,10 @@ import json import logging +import traceback -logger = logging.getLogger(__name__) +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +logger = logging.getLogger(__name__) ### PEP-249 Mandated ### class Error(Exception): @@ -11,10 +13,23 @@ class Error(Exception): `context`: Optional extra context about the error. MUST be JSON serializable """ - def __init__(self, message=None, context=None, *args, **kwargs): + def __init__( + self, message=None, context=None, connection_uuid=None, *args, **kwargs + ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} + self.connection_uuid = connection_uuid + + error_name = self.__class__.__name__ + if self.connection_uuid: + try: + telemetry_client = TelemetryClientFactory.get_telemetry_client( + self.connection_uuid + ) + telemetry_client.export_failure_log(error_name, self.message) + except Exception as telemetry_error: + logger.error(f"Failed to send error to telemetry: {telemetry_error}") def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d095d685c..b2caa6c1f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -8,6 +8,7 @@ from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, + DriverErrorInfo, ) from databricks.sql.telemetry.models.frontend_logs import ( TelemetryFrontendLog, @@ -26,7 +27,6 @@ import uuid import locale from abc import ABC, abstractmethod -from databricks.sql import __version__ logger = logging.getLogger(__name__) @@ -34,22 +34,26 @@ class TelemetryHelper: """Helper class for getting telemetry related information.""" - _DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( - driver_name="Databricks SQL Python Connector", - driver_version=__version__, - runtime_name=f"Python {sys.version.split()[0]}", - runtime_vendor=platform.python_implementation(), - runtime_version=platform.python_version(), - os_name=platform.system(), - os_version=platform.release(), - os_arch=platform.machine(), - client_app_name=None, # TODO: Add client app name - locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], - char_set_encoding=sys.getdefaultencoding(), - ) + _DRIVER_SYSTEM_CONFIGURATION = None @classmethod def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + if cls._DRIVER_SYSTEM_CONFIGURATION is None: + from databricks.sql import __version__ + + cls._DRIVER_SYSTEM_CONFIGURATION = DriverSystemConfiguration( + driver_name="Databricks SQL Python Connector", + driver_version=__version__, + runtime_name=f"Python {sys.version.split()[0]}", + runtime_vendor=platform.python_implementation(), + runtime_version=platform.python_version(), + os_name=platform.system(), + os_version=platform.release(), + os_arch=platform.machine(), + client_app_name=None, # TODO: Add client app name + locale_name=locale.getlocale()[0] or locale.getdefaultlocale()[0], + char_set_encoding=sys.getdefaultencoding(), + ) return cls._DRIVER_SYSTEM_CONFIGURATION @staticmethod @@ -99,7 +103,11 @@ class BaseTelemetryClient(ABC): """ @abstractmethod - def export_initial_telemetry_log(self, **kwargs): + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + @abstractmethod + def export_failure_log(self, error_name, error_message): pass @abstractmethod @@ -123,6 +131,9 @@ def __new__(cls): def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass + def export_failure_log(self, error_name, error_message): + pass + def close(self): pass @@ -255,10 +266,33 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): self.export_event(telemetry_frontend_log) + def export_failure_log(self, error_name, error_message): + logger.debug("Exporting failure log for connection %s", self._connection_uuid) + error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + error_info=error_info, + ) + ), + ) + self.export_event(telemetry_frontend_log) + def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self.flush() + TelemetryClientFactory.close(self._connection_uuid) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e3dc38ad5..233b4f55c 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,6 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() + self._connection_uuid = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -255,12 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, connection_uuid=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: - raise DatabaseError(response.status.errorMessage) + raise DatabaseError( + response.status.errorMessage, connection_uuid=connection_uuid + ) @staticmethod def _extract_error_message_from_headers(headers): @@ -311,7 +314,10 @@ def _handle_request_error(self, error_info, attempt, elapsed): no_retry_reason, attempt, elapsed ) network_request_error = RequestError( - user_friendly_error_message, full_error_info_context, error_info.error + user_friendly_error_message, + full_error_info_context, + self._connection_uuid, + error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -483,7 +489,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._connection_uuid) return response error_info = response_or_error_info @@ -497,7 +503,8 @@ def _check_protocol_version(self, t_open_session_resp): raise OperationalError( "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " - "instead got: {}".format(protocol_version) + "instead got: {}".format(protocol_version), + connection_uuid=self._connection_uuid, ) def _check_initial_namespace(self, catalog, schema, response): @@ -510,14 +517,16 @@ def _check_initial_namespace(self, catalog, schema, response): ): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " - "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." + "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", + connection_uuid=self._connection_uuid, ) if catalog: if not response.canUseMultipleCatalogs: raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " - + "but server does not support multiple catalogs.".format(catalog) # type: ignore + + "but server does not support multiple catalogs.".format(catalog), # type: ignore + connection_uuid=self._connection_uuid, ) def _check_session_configuration(self, session_configuration): @@ -531,7 +540,8 @@ def _check_session_configuration(self, session_configuration): "while using the Databricks SQL connector, it must be false not {}".format( TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], - ) + ), + connection_uuid=self._connection_uuid, ) def open_session(self, session_configuration, catalog, schema): @@ -562,6 +572,11 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) + self._connection_uuid = ( + self.handle_to_hex_id(response.sessionHandle) + if response.sessionHandle + else None + ) return response except: self._transport.close() @@ -586,6 +601,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, + connection_uuid=self._connection_uuid, ) else: raise ServerOperationError( @@ -595,6 +611,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, + connection_uuid=self._connection_uuid, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -605,6 +622,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, + connection_uuid=self._connection_uuid, ) def _poll_for_status(self, op_handle): @@ -625,7 +643,10 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: - raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) + raise OperationalError( + "Unsupported TRowSet instance {}".format(t_row_set), + connection_uuid=self._connection_uuid, + ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): @@ -633,7 +654,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -664,7 +685,8 @@ def map_type(t_type_entry): # Current thriftserver implementation should always return a primitiveEntry, # even for complex types raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + connection_uuid=connection_uuid, ) def convert_col(t_column_desc): @@ -675,7 +697,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, connection_uuid=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -684,7 +706,8 @@ def _col_to_description(col): cleaned_type = (name[:-5] if name.endswith("_TYPE") else name).lower() else: raise OperationalError( - "Thrift protocol error: t_type_entry not a primitiveEntry" + "Thrift protocol error: t_type_entry not a primitiveEntry", + connection_uuid=connection_uuid, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -697,7 +720,8 @@ def _col_to_description(col): else: raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " - "primitiveEntry {}".format(type_entry.primitiveEntry) + "primitiveEntry {}".format(type_entry.primitiveEntry), + connection_uuid=connection_uuid, ) else: precision, scale = None, None @@ -705,9 +729,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, connection_uuid=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, connection_uuid) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -727,7 +752,8 @@ def _results_message_to_execute_response(self, resp, operation_state): ttypes.TSparkRowSetType._VALUES_TO_NAMES[ t_result_set_metadata_resp.resultFormat ] - ) + ), + connection_uuid=self._connection_uuid, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -737,13 +763,15 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, self._connection_uuid ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -804,13 +832,15 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, self._connection_uuid ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -864,23 +894,23 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus + t_spark_direct_results.operationStatus, connection_uuid ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata + t_spark_direct_results.resultSetMetadata, connection_uuid ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet + t_spark_direct_results.resultSet, connection_uuid ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation + t_spark_direct_results.closeOperation, connection_uuid ) def execute_command( @@ -1029,7 +1059,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1040,7 +1070,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) def fetch_results( self, @@ -1074,7 +1104,8 @@ def fetch_results( raise DataError( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset - ) + ), + connection_uuid=self._connection_uuid, ) queue = ResultSetQueueFactory.build_queue( From 5305308994e3ef3f1a2f76c0a8f147638a83a91c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 11 Jun 2025 09:36:41 +0530 Subject: [PATCH 02/86] changed logger.error to logger.debug in exc.py Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 92577d548..61d3a6234 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -29,7 +29,7 @@ def __init__( ) telemetry_client.export_failure_log(error_name, self.message) except Exception as telemetry_error: - logger.error(f"Failed to send error to telemetry: {telemetry_error}") + logger.debug(f"Failed to send error to telemetry: {telemetry_error}") def __str__(self): return self.message From ba83c33561f6f7e86b55bec3443be26fc8fc1c63 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 11 Jun 2025 11:27:53 +0530 Subject: [PATCH 03/86] Fix telemetry loss during Python shutdown Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index b2caa6c1f..eb0dee82c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -308,6 +308,8 @@ class TelemetryClientFactory: _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.Lock() # Thread safety for factory operations + _original_excepthook = None + _excepthook_installed = False @classmethod def _initialize(cls): @@ -318,11 +320,58 @@ def _initialize(cls): cls._executor = ThreadPoolExecutor( max_workers=10 ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() cls._initialized = True logger.debug( "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + import sys + + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + try: + # Flush existing thread pool work and wait for completion + logger.debug( + "Flushing pending telemetry and waiting for thread pool completion..." + ) + for uuid, client in cls._clients.items(): + if hasattr(client, "flush"): + try: + client.flush() # Submit any pending events + except Exception as e: + logger.debug( + "Failed to flush telemetry for connection %s: %s", uuid, e + ) + + if cls._executor: + try: + cls._executor.shutdown( + wait=True + ) # This waits for all submitted work to complete + logger.debug("Thread pool shutdown completed successfully") + except Exception as e: + logger.debug("Thread pool shutdown failed: %s", e) + + except Exception as e: + logger.debug("Exception in excepthook telemetry handler: %s", e) + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) + @staticmethod def initialize_telemetry_client( telemetry_enabled, From 131db92293771bda983e3305371c8de460281704 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 10:21:47 +0530 Subject: [PATCH 04/86] unit tests for export_failure_log Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry.py | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 478205b18..b210d61b8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -83,6 +83,12 @@ def test_export_initial_telemetry_log(self, noop_telemetry_client): driver_connection_params=MagicMock(), user_agent="test" ) + def test_export_failure_log(self, noop_telemetry_client): + """Test that export_failure_log does nothing.""" + noop_telemetry_client.export_failure_log( + error_name="TestError", error_message="Test error message" + ) + def test_close(self, noop_telemetry_client): """Test that close does nothing.""" noop_telemetry_client.close() @@ -127,6 +133,47 @@ def test_export_initial_telemetry_log( mock_frontend_log.assert_called_once() client.export_event.assert_called_once_with(mock_frontend_log.return_value) + @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") + @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") + @patch("databricks.sql.telemetry.telemetry_client.time.time") + def test_export_failure_log( + self, + mock_time, + mock_uuid4, + mock_driver_error_info, + mock_get_driver_config, + mock_frontend_log, + telemetry_client_setup + ): + """Test exporting failure telemetry log.""" + mock_time.return_value = 2000 + mock_uuid4.return_value = "test-error-uuid" + mock_get_driver_config.return_value = "test-driver-config" + mock_driver_error_info.return_value = MagicMock() + mock_frontend_log.return_value = MagicMock() + + client = telemetry_client_setup["client"] + client.export_event = MagicMock() + + client._driver_connection_params = "test-connection-params" + client._user_agent = "test-user-agent" + + error_name = "TestError" + error_message = "This is a test error message" + + client.export_failure_log(error_name, error_message) + + mock_driver_error_info.assert_called_once_with( + error_name=error_name, + stack_trace=error_message + ) + + mock_frontend_log.assert_called_once() + + client.export_event.assert_called_once_with(mock_frontend_log.return_value) + def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] From 3abc40dcaa39e6ebfb527a0019f355d95a53164f Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 10:56:34 +0530 Subject: [PATCH 05/86] try-catch blocks to make telemetry failures non-blocking for connector operations Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 98 +++++++++++-------- 1 file changed, 57 insertions(+), 41 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index eb0dee82c..eb4edec18 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -244,56 +244,72 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): "Exporting initial telemetry log for connection %s", self._connection_uuid ) - self._driver_connection_params = driver_connection_params - self._user_agent = user_agent - - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), - driver_connection_params=self._driver_connection_params, - ) - ), - ) + try: + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + ) + ), + ) - self.export_event(telemetry_frontend_log) + self.export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export initial telemetry log: %s", e) def export_failure_log(self, error_name, error_message): logger.debug("Exporting failure log for connection %s", self._connection_uuid) - error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), - driver_connection_params=self._driver_connection_params, - error_info=error_info, - ) - ), - ) - self.export_event(telemetry_frontend_log) + try: + error_info = DriverErrorInfo( + error_name=error_name, stack_trace=error_message + ) + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._connection_uuid, + system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + driver_connection_params=self._driver_connection_params, + error_info=error_info, + ) + ), + ) + self.export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export failure log: %s", e) def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) - self.flush() + try: + self.flush() + except Exception as e: + logger.debug("Failed to flush telemetry during close: %s", e) - TelemetryClientFactory.close(self._connection_uuid) + try: + TelemetryClientFactory.close(self._connection_uuid) + except Exception as e: + logger.debug( + "Failed to remove telemetry client from telemetry clientfactory: %s", e + ) class TelemetryClientFactory: From ffa47872a063ba8bdc93da126b67a3baecaa07a7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 11:55:38 +0530 Subject: [PATCH 06/86] removed redundant try/catch blocks, added try/catch block to initialize and get telemetry client Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 11 +- .../sql/telemetry/telemetry_client.py | 107 +++++++++--------- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 61d3a6234..d7bcd5c61 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -23,13 +23,10 @@ def __init__( error_name = self.__class__.__name__ if self.connection_uuid: - try: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - self.connection_uuid - ) - telemetry_client.export_failure_log(error_name, self.message) - except Exception as telemetry_error: - logger.debug(f"Failed to send error to telemetry: {telemetry_error}") + telemetry_client = TelemetryClientFactory.get_telemetry_client( + self.connection_uuid + ) + telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index eb4edec18..216262b31 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -301,15 +301,9 @@ def close(self): logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) try: self.flush() - except Exception as e: - logger.debug("Failed to flush telemetry during close: %s", e) - - try: TelemetryClientFactory.close(self._connection_uuid) except Exception as e: - logger.debug( - "Failed to remove telemetry client from telemetry clientfactory: %s", e - ) + logger.debug("Failed to close telemetry client: %s", e) class TelemetryClientFactory: @@ -358,31 +352,27 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - try: - # Flush existing thread pool work and wait for completion - logger.debug( - "Flushing pending telemetry and waiting for thread pool completion..." - ) - for uuid, client in cls._clients.items(): - if hasattr(client, "flush"): - try: - client.flush() # Submit any pending events - except Exception as e: - logger.debug( - "Failed to flush telemetry for connection %s: %s", uuid, e - ) - - if cls._executor: + # Flush existing thread pool work and wait for completion + logger.debug( + "Flushing pending telemetry and waiting for thread pool completion..." + ) + for uuid, client in cls._clients.items(): + if hasattr(client, "flush"): try: - cls._executor.shutdown( - wait=True - ) # This waits for all submitted work to complete - logger.debug("Thread pool shutdown completed successfully") + client.flush() # Submit any pending events except Exception as e: - logger.debug("Thread pool shutdown failed: %s", e) + logger.debug( + "Failed to flush telemetry for connection %s: %s", uuid, e + ) - except Exception as e: - logger.debug("Exception in excepthook telemetry handler: %s", e) + if cls._executor: + try: + cls._executor.shutdown( + wait=True + ) # This waits for all submitted work to complete + logger.debug("Thread pool shutdown completed successfully") + except Exception as e: + logger.debug("Thread pool shutdown failed: %s", e) # Call the original exception handler to maintain normal behavior if cls._original_excepthook: @@ -396,35 +386,48 @@ def initialize_telemetry_client( host_url, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" - TelemetryClientFactory._initialize() + try: + TelemetryClientFactory._initialize() - with TelemetryClientFactory._lock: - if connection_uuid not in TelemetryClientFactory._clients: - logger.debug( - "Creating new TelemetryClient for connection %s", connection_uuid - ) - if telemetry_enabled: - TelemetryClientFactory._clients[connection_uuid] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, + with TelemetryClientFactory._lock: + if connection_uuid not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + connection_uuid, ) - else: - TelemetryClientFactory._clients[ - connection_uuid - ] = NoopTelemetryClient() + if telemetry_enabled: + TelemetryClientFactory._clients[ + connection_uuid + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + connection_uuid + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[connection_uuid] = NoopTelemetryClient() @staticmethod def get_telemetry_client(connection_uuid): """Get the telemetry client for a specific connection""" - if connection_uuid in TelemetryClientFactory._clients: - return TelemetryClientFactory._clients[connection_uuid] - else: - logger.error( - "Telemetry client not initialized for connection %s", connection_uuid - ) + try: + if connection_uuid in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[connection_uuid] + else: + logger.error( + "Telemetry client not initialized for connection %s", + connection_uuid, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) return NoopTelemetryClient() @staticmethod From cc077f3b6032bee52e99dfe948fd26b3dc9911be Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 13:54:05 +0530 Subject: [PATCH 07/86] skip null fields in telemetry request Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/models/event.py | 16 +++++++------- .../sql/telemetry/models/frontend_logs.py | 10 ++++----- src/databricks/sql/telemetry/utils.py | 21 +++++++++++++++++++ 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 4429a7626..c00738810 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -9,7 +9,7 @@ ExecutionResultFormat, ) from typing import Optional -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import to_json_compact @dataclass @@ -26,7 +26,7 @@ class HostDetails: port: int def to_json(self): - return json.dumps(asdict(self)) + return to_json_compact(self) @dataclass @@ -52,7 +52,7 @@ class DriverConnectionParameters: socket_timeout: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -88,7 +88,7 @@ class DriverSystemConfiguration: locale_name: Optional[str] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -106,7 +106,7 @@ class DriverVolumeOperation: volume_path: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -124,7 +124,7 @@ class DriverErrorInfo: stack_trace: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -146,7 +146,7 @@ class SqlExecutionEvent: retry_count: int def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -179,4 +179,4 @@ class TelemetryEvent: operation_latency_ms: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index 36086a7cc..f5d58a4be 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass, asdict from databricks.sql.telemetry.models.event import TelemetryEvent -from databricks.sql.telemetry.utils import EnumEncoder +from databricks.sql.telemetry.utils import to_json_compact from typing import Optional @@ -20,7 +20,7 @@ class TelemetryClientContext: user_agent: str def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -36,7 +36,7 @@ class FrontendLogContext: client_context: TelemetryClientContext def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -52,7 +52,7 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) @dataclass @@ -75,4 +75,4 @@ class TelemetryFrontendLog: workspace_id: Optional[int] = None def to_json(self): - return json.dumps(asdict(self), cls=EnumEncoder) + return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 6a4d64eba..8be2c9873 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,5 +1,6 @@ import json from enum import Enum +from dataclasses import asdict class EnumEncoder(json.JSONEncoder): @@ -13,3 +14,23 @@ def default(self, obj): if isinstance(obj, Enum): return obj.value return super().default(obj) + + +def filter_none_values(data): + """ + Recursively remove None values from dictionaries. + This reduces telemetry payload size by excluding null fields. + """ + if isinstance(data, dict): + return {k: filter_none_values(v) for k, v in data.items() if v is not None} + else: + return data + + +def to_json_compact(dataclass_obj): + """ + Convert a dataclass to JSON string, excluding None values. + """ + data_dict = asdict(dataclass_obj) + filtered_dict = filter_none_values(data_dict) + return json.dumps(filtered_dict, cls=EnumEncoder) From 2c6fd44cb18b9c8b8f910d34bc7e02a24077e58b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 12 Jun 2025 23:12:39 +0530 Subject: [PATCH 08/86] removed dup import, renamed func, changed a filter_null_values to lamda Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 8 +++---- src/databricks/sql/telemetry/utils.py | 21 +++++++------------ tests/unit/test_telemetry.py | 4 ++-- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 216262b31..1099c81cd 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -37,7 +37,7 @@ class TelemetryHelper: _DRIVER_SYSTEM_CONFIGURATION = None @classmethod - def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration: + def get_driver_system_configuration(cls) -> DriverSystemConfiguration: if cls._DRIVER_SYSTEM_CONFIGURATION is None: from databricks.sql import __version__ @@ -259,7 +259,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, ) ), @@ -286,7 +286,7 @@ def export_failure_log(self, error_name, error_message): entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( session_id=self._connection_uuid, - system_configuration=TelemetryHelper.getDriverSystemConfiguration(), + system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, error_info=error_info, ) @@ -340,8 +340,6 @@ def _initialize(cls): def _install_exception_hook(cls): """Install global exception handler for unhandled exceptions""" if not cls._excepthook_installed: - import sys - cls._original_excepthook = sys.excepthook sys.excepthook = cls._handle_unhandled_exception cls._excepthook_installed = True diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 8be2c9873..2ae87b96e 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -16,21 +16,14 @@ def default(self, obj): return super().default(obj) -def filter_none_values(data): - """ - Recursively remove None values from dictionaries. - This reduces telemetry payload size by excluding null fields. - """ - if isinstance(data, dict): - return {k: filter_none_values(v) for k, v in data.items() if v is not None} - else: - return data - - def to_json_compact(dataclass_obj): """ Convert a dataclass to JSON string, excluding None values. """ - data_dict = asdict(dataclass_obj) - filtered_dict = filter_none_values(data_dict) - return json.dumps(filtered_dict, cls=EnumEncoder) + return json.dumps( + asdict( + dataclass_obj, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b210d61b8..a3e0239db 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -98,7 +98,7 @@ class TestTelemetryClient: """Tests for the TelemetryClient class.""" @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") @patch("databricks.sql.telemetry.telemetry_client.time.time") def test_export_initial_telemetry_log( @@ -134,7 +134,7 @@ def test_export_initial_telemetry_log( client.export_event.assert_called_once_with(mock_frontend_log.return_value) @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") @patch("databricks.sql.telemetry.telemetry_client.time.time") From 89540a169e101c368334f8065e3bf1b1573cf2b7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 02:26:24 +0530 Subject: [PATCH 09/86] removed unnecassary class variable and a redundant try/except block Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 5 ++--- src/databricks/sql/telemetry/telemetry_client.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index d7bcd5c61..cc7a47cb4 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -19,12 +19,11 @@ def __init__( super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} - self.connection_uuid = connection_uuid error_name = self.__class__.__name__ - if self.connection_uuid: + if connection_uuid: telemetry_client = TelemetryClientFactory.get_telemetry_client( - self.connection_uuid + connection_uuid ) telemetry_client.export_failure_log(error_name, self.message) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 1099c81cd..3402cfd70 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -356,12 +356,7 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): ) for uuid, client in cls._clients.items(): if hasattr(client, "flush"): - try: - client.flush() # Submit any pending events - except Exception as e: - logger.debug( - "Failed to flush telemetry for connection %s: %s", uuid, e - ) + client.flush() # Submit any pending events if cls._executor: try: From 52a1152b33a39be7a7b155ffcacb6b482829ddd2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 03:16:04 +0530 Subject: [PATCH 10/86] public functions defined at interface level Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 3402cfd70..f2212d7a4 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -102,6 +102,14 @@ class BaseTelemetryClient(ABC): It is used to define the interface for telemetry clients. """ + @abstractmethod + def export_event(self, event): + pass + + @abstractmethod + def flush(self): + pass + @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -128,6 +136,12 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance + def export_event(self, event): + pass + + def flush(self): + pass + def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -354,9 +368,8 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): logger.debug( "Flushing pending telemetry and waiting for thread pool completion..." ) - for uuid, client in cls._clients.items(): - if hasattr(client, "flush"): - client.flush() # Submit any pending events + for client in cls._clients.items(): + client.flush() # Submit any pending events if cls._executor: try: From 3dcdcfa70ff4562da24fa41e006392f2825602fb Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:06:58 +0530 Subject: [PATCH 11/86] changed export_event and flush to private functions Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- .../sql/telemetry/telemetry_client.py | 44 ++++--------------- tests/unit/test_telemetry.py | 18 ++++---- 3 files changed, 17 insertions(+), 47 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f0f1f5b8a..a4f56c4f9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -467,7 +467,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - self._telemetry_client.close() + TelemetryClientFactory.close(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f2212d7a4..f6e3daad1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -102,14 +102,6 @@ class BaseTelemetryClient(ABC): It is used to define the interface for telemetry clients. """ - @abstractmethod - def export_event(self, event): - pass - - @abstractmethod - def flush(self): - pass - @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -136,12 +128,6 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance - def export_event(self, event): - pass - - def flush(self): - pass - def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -182,7 +168,7 @@ def __init__( self._host_url = host_url self._executor = executor - def export_event(self, event): + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._connection_uuid) with self._lock: @@ -191,9 +177,9 @@ def export_event(self, event): logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) - self.flush() + self._flush() - def flush(self): + def _flush(self): """Flush the current batch of events to the server""" with self._lock: events_to_flush = self._events_batch.copy() @@ -313,11 +299,7 @@ def export_failure_log(self, error_name, error_message): def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) - try: - self.flush() - TelemetryClientFactory.close(self._connection_uuid) - except Exception as e: - logger.debug("Failed to close telemetry client: %s", e) + self._flush() class TelemetryClientFactory: @@ -365,20 +347,8 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): logger.debug("Handling unhandled exception: %s", exc_type.__name__) # Flush existing thread pool work and wait for completion - logger.debug( - "Flushing pending telemetry and waiting for thread pool completion..." - ) - for client in cls._clients.items(): - client.flush() # Submit any pending events - - if cls._executor: - try: - cls._executor.shutdown( - wait=True - ) # This waits for all submitted work to complete - logger.debug("Thread pool shutdown completed successfully") - except Exception as e: - logger.debug("Thread pool shutdown failed: %s", e) + for uuid, _ in cls._clients.items(): + cls.close(uuid) # Call the original exception handler to maintain normal behavior if cls._original_excepthook: @@ -445,6 +415,7 @@ def close(connection_uuid): logger.debug( "Removing telemetry client for connection %s", connection_uuid ) + TelemetryClientFactory.get_telemetry_client(connection_uuid).close() TelemetryClientFactory._clients.pop(connection_uuid, None) # Shutdown executor if no more clients @@ -455,3 +426,4 @@ def close(connection_uuid): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False + diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index a3e0239db..97b8f276b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -177,18 +177,18 @@ def test_export_failure_log( def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] - client.flush = MagicMock() + client._flush = MagicMock() for i in range(5): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_not_called() + client._flush.assert_not_called() assert len(client._events_batch) == 5 for i in range(5, 10): - client.export_event(f"event-{i}") + client._export_event(f"event-{i}") - client.flush.assert_called_once() + client._flush.assert_called_once() assert len(client._events_batch) == 10 @patch("requests.post") @@ -244,7 +244,7 @@ def test_flush(self, telemetry_client_setup): client._events_batch = ["event1", "event2"] client._send_telemetry = MagicMock() - client.flush() + client._flush() client._send_telemetry.assert_called_once_with(["event1", "event2"]) assert client._events_batch == [] @@ -253,13 +253,11 @@ def test_flush(self, telemetry_client_setup): def test_close(self, mock_factory_class, telemetry_client_setup): """Test closing the client.""" client = telemetry_client_setup["client"] - connection_uuid = telemetry_client_setup["connection_uuid"] - client.flush = MagicMock() + client._flush = MagicMock() client.close() - client.flush.assert_called_once() - mock_factory_class.close.assert_called_once_with(connection_uuid) + client._flush.assert_called_once() class TestTelemetryClientFactory: From b2714c9738439d7cb2c947fd2654c5ed7444fe99 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:10:10 +0530 Subject: [PATCH 12/86] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f6e3daad1..fe1c0e191 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -426,4 +426,3 @@ def close(connection_uuid): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - From 377a87bb2f493b6442c8a921b7c3e51e3c72b44d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 11:28:34 +0530 Subject: [PATCH 13/86] changed connection_uuid to thread local in thrift backend Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/thrift_backend.py | 77 ++++++++++++++-------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 233b4f55c..79fc7f1b0 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -72,6 +72,9 @@ "_retry_delay_default": (float, 5, 1, 60), } +# Add thread local storage +_connection_uuid = threading.local() + class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE @@ -223,7 +226,7 @@ def __init__( raise self._request_lock = threading.RLock() - self._connection_uuid = None + _connection_uuid.value = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -256,13 +259,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, connection_uuid=None): + def _check_response_for_error(response): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( - response.status.errorMessage, connection_uuid=connection_uuid + response.status.errorMessage, + connection_uuid=getattr(_connection_uuid, "value", None), ) @staticmethod @@ -316,7 +320,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._connection_uuid, + getattr(_connection_uuid, "value", None), error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -489,7 +493,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response, self._connection_uuid) + ThriftBackend._check_response_for_error(response) return response error_info = response_or_error_info @@ -504,7 +508,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _check_initial_namespace(self, catalog, schema, response): @@ -518,7 +522,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) if catalog: @@ -526,7 +530,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _check_session_configuration(self, session_configuration): @@ -541,7 +545,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def open_session(self, session_configuration, catalog, schema): @@ -572,7 +576,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - self._connection_uuid = ( + _connection_uuid.value = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -601,7 +605,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) else: raise ServerOperationError( @@ -611,7 +615,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -622,7 +626,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def _poll_for_status(self, op_handle): @@ -645,7 +649,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -654,7 +658,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): + def _hive_schema_to_arrow_schema(t_table_schema): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -686,7 +690,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) def convert_col(t_column_desc): @@ -697,7 +701,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, connection_uuid=None): + def _col_to_description(col): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -707,7 +711,7 @@ def _col_to_description(col, connection_uuid=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -721,7 +725,7 @@ def _col_to_description(col, connection_uuid=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) else: precision, scale = None, None @@ -729,10 +733,9 @@ def _col_to_description(col, connection_uuid=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema, connection_uuid=None): + def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col, connection_uuid) - for col in t_table_schema.columns + ThriftBackend._col_to_description(col) for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -753,7 +756,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -763,15 +766,13 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid - ) + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) .serialize() .to_pybytes() ) @@ -832,15 +833,13 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid - ) + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) .serialize() .to_pybytes() ) @@ -894,23 +893,23 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): + def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus, connection_uuid + t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata, connection_uuid + t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet, connection_uuid + t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation, connection_uuid + t_spark_direct_results.closeOperation ) def execute_command( @@ -1059,7 +1058,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1070,7 +1069,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults) def fetch_results( self, @@ -1105,7 +1104,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=self._connection_uuid, + connection_uuid=getattr(_connection_uuid, "value", None), ) queue = ResultSetQueueFactory.build_queue( From c9376b8b8ff36f9f4f3809cf4f19bf13d0b52c4b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 12:04:42 +0530 Subject: [PATCH 14/86] made errors more specific Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 37 +++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a4f56c4f9..f9a011b11 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -18,6 +18,9 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + InterfaceError, + NotSupportedError, + ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -421,7 +424,7 @@ def cursor( Will throw an Error if the connection has been closed. """ if not self.open: - raise Error( + raise InterfaceError( "Cannot create cursor from closed connection", connection_uuid=self.get_session_id_hex(), ) @@ -529,7 +532,7 @@ def __iter__(self): for row in self.active_result_set: yield row else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -669,7 +672,7 @@ def _close_and_clear_active_result_set(self): def _check_not_closed(self): if not self.open: - raise Error( + raise InterfaceError( "Attempting operation on closed cursor", connection_uuid=self.connection.get_session_id_hex(), ) @@ -689,7 +692,7 @@ def _handle_staging_operation( elif isinstance(staging_allowed_local_path, type(list())): _staging_allowed_local_paths = staging_allowed_local_path else: - raise Error( + raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", connection_uuid=self.connection.get_session_id_hex(), ) @@ -719,7 +722,7 @@ def _handle_staging_operation( else: continue if not allow_operation: - raise Error( + raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", connection_uuid=self.connection.get_session_id_hex(), ) @@ -749,7 +752,7 @@ def _handle_staging_operation( handler_args.pop("local_file") return self._handle_staging_remove(**handler_args) else: - raise Error( + raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", connection_uuid=self.connection.get_session_id_hex(), @@ -764,7 +767,7 @@ def _handle_staging_put( """ if local_file is None: - raise Error( + raise ProgrammingError( "Cannot perform PUT without specifying a local_file", connection_uuid=self.connection.get_session_id_hex(), ) @@ -783,7 +786,7 @@ def _handle_staging_put( # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -803,7 +806,7 @@ def _handle_staging_get( """ if local_file is None: - raise Error( + raise ProgrammingError( "Cannot perform GET without specifying a local_file", connection_uuid=self.connection.get_session_id_hex(), ) @@ -813,7 +816,7 @@ def _handle_staging_get( # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True if not r.ok: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -829,7 +832,7 @@ def _handle_staging_remove( r = requests.delete(url=presigned_url, headers=headers) if not r.ok: - raise Error( + raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1029,7 +1032,7 @@ def get_async_execution_result(self): return self else: - raise Error( + raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1181,7 +1184,7 @@ def fetchall(self) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchall() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1198,7 +1201,7 @@ def fetchone(self) -> Optional[Row]: if self.active_result_set: return self.active_result_set.fetchone() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1223,7 +1226,7 @@ def fetchmany(self, size: int) -> List[Row]: if self.active_result_set: return self.active_result_set.fetchmany(size) else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1233,7 +1236,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": if self.active_result_set: return self.active_result_set.fetchall_arrow() else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) @@ -1243,7 +1246,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) else: - raise Error( + raise ProgrammingError( "There is no active result set", connection_uuid=self.connection.get_session_id_hex(), ) From bbfadf2b16f53d317e0ca4ba36b5b20f30eea533 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 14:14:27 +0530 Subject: [PATCH 15/86] revert change to connection_uuid Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/thrift_backend.py | 82 +++++++++++++++------------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 79fc7f1b0..7c47da2b1 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -72,9 +72,6 @@ "_retry_delay_default": (float, 5, 1, 60), } -# Add thread local storage -_connection_uuid = threading.local() - class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE @@ -226,7 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() - _connection_uuid.value = None + self._connection_uuid = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -259,14 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response): + def _check_response_for_error(response, connection_uuid=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) @staticmethod @@ -320,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - getattr(_connection_uuid, "value", None), + self._connection_uuid, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -493,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftBackend._check_response_for_error(response, self._connection_uuid) return response error_info = response_or_error_info @@ -508,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _check_initial_namespace(self, catalog, schema, response): @@ -522,7 +519,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) if catalog: @@ -530,7 +527,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _check_session_configuration(self, session_configuration): @@ -545,7 +542,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def open_session(self, session_configuration, catalog, schema): @@ -576,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - _connection_uuid.value = ( + self._connection_uuid = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -605,7 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) else: raise ServerOperationError( @@ -615,7 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -626,7 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) def _poll_for_status(self, op_handle): @@ -649,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -658,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema): + def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -690,7 +687,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) def convert_col(t_column_desc): @@ -701,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col): + def _col_to_description(col, connection_uuid=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -711,7 +708,7 @@ def _col_to_description(col): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -725,7 +722,7 @@ def _col_to_description(col): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=connection_uuid, ) else: precision, scale = None, None @@ -733,9 +730,10 @@ def _col_to_description(col): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema): + def _hive_schema_to_description(t_table_schema, connection_uuid=None): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftBackend._col_to_description(col, connection_uuid) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -756,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -766,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state): or direct_results.resultSet.hasMoreRows ) description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._connection_uuid, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -833,13 +834,16 @@ def get_execution_result(self, op_handle, cursor): is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema + t_result_set_metadata_resp.schema, + self._connection_uuid, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + or self._hive_schema_to_arrow_schema( + t_result_set_metadata_resp.schema, self._connection_uuid + ) .serialize() .to_pybytes() ) @@ -893,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results): + def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( - t_spark_direct_results.operationStatus + t_spark_direct_results.operationStatus, + connection_uuid, ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSetMetadata + t_spark_direct_results.resultSetMetadata, + connection_uuid, ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( - t_spark_direct_results.resultSet + t_spark_direct_results.resultSet, + connection_uuid, ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( - t_spark_direct_results.closeOperation + t_spark_direct_results.closeOperation, + connection_uuid, ) def execute_command( @@ -1058,7 +1066,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1069,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults) + self._check_direct_results_for_error(resp.directResults, self._connection_uuid) def fetch_results( self, @@ -1104,7 +1112,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=getattr(_connection_uuid, "value", None), + connection_uuid=self._connection_uuid, ) queue = ResultSetQueueFactory.build_queue( From 9bce26b3eac736bf896850ca5448d4c831febde9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 14:24:21 +0530 Subject: [PATCH 16/86] reverting change in close in telemetry client Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f9a011b11..04ef4584f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -470,7 +470,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - TelemetryClientFactory.close(self.get_session_id_hex()) + self._telemetry_client.close() def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fe1c0e191..ddb0a3974 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -300,6 +300,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self._flush() + TelemetryClientFactory.close(self._connection_uuid) class TelemetryClientFactory: @@ -415,7 +416,6 @@ def close(connection_uuid): logger.debug( "Removing telemetry client for connection %s", connection_uuid ) - TelemetryClientFactory.get_telemetry_client(connection_uuid).close() TelemetryClientFactory._clients.pop(connection_uuid, None) # Shutdown executor if no more clients From ef4514d4020e164fbc02025c9ff7bc8c831fb360 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 16:23:35 +0530 Subject: [PATCH 17/86] JsonSerializableMixin Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/models/event.py | 40 +++++-------------- .../sql/telemetry/models/frontend_logs.py | 25 +++--------- .../sql/telemetry/telemetry_client.py | 11 ++--- src/databricks/sql/telemetry/utils.py | 32 +++++++++------ tests/unit/test_telemetry.py | 8 ++-- 5 files changed, 44 insertions(+), 72 deletions(-) diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index c00738810..f5496deec 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -1,5 +1,4 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.enums import ( AuthMech, AuthFlow, @@ -9,11 +8,11 @@ ExecutionResultFormat, ) from typing import Optional -from databricks.sql.telemetry.utils import to_json_compact +from databricks.sql.telemetry.utils import JsonSerializableMixin @dataclass -class HostDetails: +class HostDetails(JsonSerializableMixin): """ Represents the host connection details for a Databricks workspace. @@ -25,12 +24,9 @@ class HostDetails: host_url: str port: int - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverConnectionParameters: +class DriverConnectionParameters(JsonSerializableMixin): """ Contains all connection parameters used to establish a connection to Databricks SQL. This includes authentication details, host information, and connection settings. @@ -51,12 +47,9 @@ class DriverConnectionParameters: auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverSystemConfiguration: +class DriverSystemConfiguration(JsonSerializableMixin): """ Contains system-level configuration information about the client environment. This includes details about the operating system, runtime, and driver version. @@ -87,12 +80,9 @@ class DriverSystemConfiguration: client_app_name: Optional[str] = None locale_name: Optional[str] = None - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverVolumeOperation: +class DriverVolumeOperation(JsonSerializableMixin): """ Represents a volume operation performed by the driver. Used for tracking volume-related operations in telemetry. @@ -105,12 +95,9 @@ class DriverVolumeOperation: volume_operation_type: DriverVolumeOperationType volume_path: str - def to_json(self): - return to_json_compact(self) - @dataclass -class DriverErrorInfo: +class DriverErrorInfo(JsonSerializableMixin): """ Contains detailed information about errors that occur during driver operations. Used for error tracking and debugging in telemetry. @@ -123,12 +110,9 @@ class DriverErrorInfo: error_name: str stack_trace: str - def to_json(self): - return to_json_compact(self) - @dataclass -class SqlExecutionEvent: +class SqlExecutionEvent(JsonSerializableMixin): """ Represents a SQL query execution event. Contains details about the query execution, including type, compression, and result format. @@ -145,12 +129,9 @@ class SqlExecutionEvent: execution_result: ExecutionResultFormat retry_count: int - def to_json(self): - return to_json_compact(self) - @dataclass -class TelemetryEvent: +class TelemetryEvent(JsonSerializableMixin): """ Main telemetry event class that aggregates all telemetry data. Contains information about the session, system configuration, connection parameters, @@ -177,6 +158,3 @@ class TelemetryEvent: sql_operation: Optional[SqlExecutionEvent] = None error_info: Optional[DriverErrorInfo] = None operation_latency_ms: Optional[int] = None - - def to_json(self): - return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/models/frontend_logs.py b/src/databricks/sql/telemetry/models/frontend_logs.py index f5d58a4be..4cc314ec3 100644 --- a/src/databricks/sql/telemetry/models/frontend_logs.py +++ b/src/databricks/sql/telemetry/models/frontend_logs.py @@ -1,12 +1,11 @@ -import json -from dataclasses import dataclass, asdict +from dataclasses import dataclass from databricks.sql.telemetry.models.event import TelemetryEvent -from databricks.sql.telemetry.utils import to_json_compact +from databricks.sql.telemetry.utils import JsonSerializableMixin from typing import Optional @dataclass -class TelemetryClientContext: +class TelemetryClientContext(JsonSerializableMixin): """ Contains client-side context information for telemetry events. This includes timestamp and user agent information for tracking when and how the client is being used. @@ -19,12 +18,9 @@ class TelemetryClientContext: timestamp_millis: int user_agent: str - def to_json(self): - return to_json_compact(self) - @dataclass -class FrontendLogContext: +class FrontendLogContext(JsonSerializableMixin): """ Wrapper for client context information in frontend logs. Provides additional context about the client environment for telemetry events. @@ -35,12 +31,9 @@ class FrontendLogContext: client_context: TelemetryClientContext - def to_json(self): - return to_json_compact(self) - @dataclass -class FrontendLogEntry: +class FrontendLogEntry(JsonSerializableMixin): """ Contains the actual telemetry event data in a frontend log. Wraps the SQL driver log information for frontend processing. @@ -51,12 +44,9 @@ class FrontendLogEntry: sql_driver_log: TelemetryEvent - def to_json(self): - return to_json_compact(self) - @dataclass -class TelemetryFrontendLog: +class TelemetryFrontendLog(JsonSerializableMixin): """ Main container for frontend telemetry data. Aggregates workspace information, event ID, context, and the actual log entry. @@ -73,6 +63,3 @@ class TelemetryFrontendLog: context: FrontendLogContext entry: FrontendLogEntry workspace_id: Optional[int] = None - - def to_json(self): - return to_json_compact(self) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index ddb0a3974..403379aff 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -265,7 +265,8 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): ), ) - self.export_event(telemetry_frontend_log) + self._export_event(telemetry_frontend_log) + except Exception as e: logger.debug("Failed to export initial telemetry log: %s", e) @@ -292,7 +293,7 @@ def export_failure_log(self, error_name, error_message): ) ), ) - self.export_event(telemetry_frontend_log) + self._export_event(telemetry_frontend_log) except Exception as e: logger.debug("Failed to export failure log: %s", e) @@ -347,9 +348,9 @@ def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - # Flush existing thread pool work and wait for completion - for uuid, _ in cls._clients.items(): - cls.close(uuid) + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() # Call the original exception handler to maintain normal behavior if cls._original_excepthook: diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 2ae87b96e..d14e2fcd4 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,6 +1,25 @@ import json from enum import Enum from dataclasses import asdict +from abc import ABC +from typing import Any + + +class JsonSerializableMixin(ABC): + """Mixin class to provide JSON serialization capabilities to dataclasses.""" + + def to_json(self) -> str: + """ + Convert the object to a JSON string, excluding None values. + Handles Enum serialization and filters out None values from the output. + """ + return json.dumps( + asdict( + self, + dict_factory=lambda data: {k: v for k, v in data if v is not None}, + ), + cls=EnumEncoder, + ) class EnumEncoder(json.JSONEncoder): @@ -14,16 +33,3 @@ def default(self, obj): if isinstance(obj, Enum): return obj.value return super().default(obj) - - -def to_json_compact(dataclass_obj): - """ - Convert a dataclass to JSON string, excluding None values. - """ - return json.dumps( - asdict( - dataclass_obj, - dict_factory=lambda data: {k: v for k, v in data if v is not None}, - ), - cls=EnumEncoder, - ) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 97b8f276b..35eba8157 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -117,7 +117,7 @@ def test_export_initial_telemetry_log( client = telemetry_client_setup["client"] host_url = telemetry_client_setup["host_url"] - client.export_event = MagicMock() + client._export_event = MagicMock() driver_connection_params = DriverConnectionParameters( http_path="test-path", @@ -131,7 +131,7 @@ def test_export_initial_telemetry_log( client.export_initial_telemetry_log(driver_connection_params, user_agent) mock_frontend_log.assert_called_once() - client.export_event.assert_called_once_with(mock_frontend_log.return_value) + client._export_event.assert_called_once_with(mock_frontend_log.return_value) @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") @@ -155,7 +155,7 @@ def test_export_failure_log( mock_frontend_log.return_value = MagicMock() client = telemetry_client_setup["client"] - client.export_event = MagicMock() + client._export_event = MagicMock() client._driver_connection_params = "test-connection-params" client._user_agent = "test-user-agent" @@ -172,7 +172,7 @@ def test_export_failure_log( mock_frontend_log.assert_called_once() - client.export_event.assert_called_once_with(mock_frontend_log.return_value) + client._export_event.assert_called_once_with(mock_frontend_log.return_value) def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" From 8924835e59aed8df7922903fe187485d4d976aee Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 13 Jun 2025 16:48:57 +0530 Subject: [PATCH 18/86] isdataclass check in JsonSerializableMixin Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index d14e2fcd4..6d95526b8 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,8 +1,7 @@ import json from enum import Enum -from dataclasses import asdict +from dataclasses import asdict, is_dataclass from abc import ABC -from typing import Any class JsonSerializableMixin(ABC): @@ -13,6 +12,11 @@ def to_json(self) -> str: Convert the object to a JSON string, excluding None values. Handles Enum serialization and filters out None values from the output. """ + if not is_dataclass(self): + raise TypeError( + f"{self.__class__.__name__} must be a dataclass to use JsonSerializableMixin" + ) + return json.dumps( asdict( self, From 65361e76f32e3199c94e809798169b6b9fe29c72 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 11:10:20 +0530 Subject: [PATCH 19/86] convert TelemetryClientFactory to module-level functions, replace NoopTelemetryClient class with NOOP_TELEMETRY_CLIENT singleton, updated tests accordingly Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 7 +- src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 243 ++++++++---------- tests/unit/test_telemetry.py | 162 ++++-------- 4 files changed, 167 insertions(+), 251 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 04ef4584f..bee60d317 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -53,8 +53,9 @@ TOperationState, ) from databricks.sql.telemetry.telemetry_client import ( - TelemetryClientFactory, TelemetryHelper, + initialize_telemetry_client, + get_telemetry_client, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -306,14 +307,14 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, connection_uuid=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) - self._telemetry_client = TelemetryClientFactory.get_telemetry_client( + self._telemetry_client = get_telemetry_client( connection_uuid=self.get_session_id_hex() ) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index cc7a47cb4..443d5605f 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -2,7 +2,7 @@ import logging import traceback -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.telemetry.telemetry_client import get_telemetry_client logger = logging.getLogger(__name__) @@ -22,9 +22,7 @@ def __init__( error_name = self.__class__.__name__ if connection_uuid: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - connection_uuid - ) + telemetry_client = get_telemetry_client(connection_uuid) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 403379aff..728220789 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -115,27 +115,16 @@ def close(self): pass -class NoopTelemetryClient(BaseTelemetryClient): - """ - NoopTelemetryClient is a telemetry client that does not send any events to the server. - It is used when telemetry is disabled. - """ - - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super(NoopTelemetryClient, cls).__new__(cls) - return cls._instance - - def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass - - def export_failure_log(self, error_name, error_message): - pass - - def close(self): - pass +# A single instance of the no-op client that can be reused +NOOP_TELEMETRY_CLIENT = type( + "NoopTelemetryClient", + (BaseTelemetryClient,), + { + "export_initial_telemetry_log": lambda self, *args, **kwargs: None, + "export_failure_log": lambda self, *args, **kwargs: None, + "close": lambda self: None, + }, +)() class TelemetryClient(BaseTelemetryClient): @@ -301,129 +290,111 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) self._flush() - TelemetryClientFactory.close(self._connection_uuid) - - -class TelemetryClientFactory: - """ - Static factory class for creating and managing telemetry clients. - It uses a thread pool to handle asynchronous operations. - """ - - _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of connection_uuid -> BaseTelemetryClient - _executor: Optional[ThreadPoolExecutor] = None - _initialized: bool = False - _lock = threading.Lock() # Thread safety for factory operations - _original_excepthook = None - _excepthook_installed = False - - @classmethod - def _initialize(cls): - """Initialize the factory if not already initialized""" - with cls._lock: - if not cls._initialized: - cls._clients = {} - cls._executor = ThreadPoolExecutor( - max_workers=10 - ) # Thread pool for async operations TODO: Decide on max workers - cls._install_exception_hook() - cls._initialized = True - logger.debug( - "TelemetryClientFactory initialized with thread pool (max_workers=10)" - ) - - @classmethod - def _install_exception_hook(cls): - """Install global exception handler for unhandled exceptions""" - if not cls._excepthook_installed: - cls._original_excepthook = sys.excepthook - sys.excepthook = cls._handle_unhandled_exception - cls._excepthook_installed = True - logger.debug("Global exception handler installed for telemetry") + _remove_telemetry_client(self._connection_uuid) + + +# Module-level state +_clients: Dict[str, BaseTelemetryClient] = {} +_executor: Optional[ThreadPoolExecutor] = None +_initialized: bool = False +_lock = threading.Lock() +_original_excepthook = None +_excepthook_installed = False + + +def _initialize(): + """Initialize the telemetry system if not already initialized""" + global _initialized, _executor + with _lock: + if not _initialized: + _clients.clear() + _executor = ThreadPoolExecutor(max_workers=10) + _install_exception_hook() + _initialized = True + logger.debug( + "Telemetry system initialized with thread pool (max_workers=10)" + ) - @classmethod - def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): - """Handle unhandled exceptions by sending telemetry and flushing thread pool""" - logger.debug("Handling unhandled exception: %s", exc_type.__name__) - clients_to_close = list(cls._clients.values()) - for client in clients_to_close: - client.close() +def _install_exception_hook(): + """Install global exception handler for unhandled exceptions""" + global _excepthook_installed, _original_excepthook + if not _excepthook_installed: + _original_excepthook = sys.excepthook + sys.excepthook = _handle_unhandled_exception + _excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") - # Call the original exception handler to maintain normal behavior - if cls._original_excepthook: - cls._original_excepthook(exc_type, exc_value, exc_traceback) - @staticmethod - def initialize_telemetry_client( - telemetry_enabled, - connection_uuid, - auth_provider, - host_url, - ): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" - try: - TelemetryClientFactory._initialize() +def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) - with TelemetryClientFactory._lock: - if connection_uuid not in TelemetryClientFactory._clients: - logger.debug( - "Creating new TelemetryClient for connection %s", - connection_uuid, - ) - if telemetry_enabled: - TelemetryClientFactory._clients[ - connection_uuid - ] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, - ) - else: - TelemetryClientFactory._clients[ - connection_uuid - ] = NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to initialize telemetry client: %s", e) - # Fallback to NoopTelemetryClient to ensure connection doesn't fail - TelemetryClientFactory._clients[connection_uuid] = NoopTelemetryClient() + clients_to_close = list(_clients.values()) + for client in clients_to_close: + client.close() - @staticmethod - def get_telemetry_client(connection_uuid): - """Get the telemetry client for a specific connection""" - try: - if connection_uuid in TelemetryClientFactory._clients: - return TelemetryClientFactory._clients[connection_uuid] - else: - logger.error( - "Telemetry client not initialized for connection %s", - connection_uuid, - ) - return NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to get telemetry client: %s", e) - return NoopTelemetryClient() + # Call the original exception handler to maintain normal behavior + if _original_excepthook: + _original_excepthook(exc_type, exc_value, exc_traceback) - @staticmethod - def close(connection_uuid): - """Close and remove the telemetry client for a specific connection""" - with TelemetryClientFactory._lock: - if connection_uuid in TelemetryClientFactory._clients: - logger.debug( - "Removing telemetry client for connection %s", connection_uuid - ) - TelemetryClientFactory._clients.pop(connection_uuid, None) +def initialize_telemetry_client( + telemetry_enabled, connection_uuid, auth_provider, host_url +): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" + try: + _initialize() - # Shutdown executor if no more clients - if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: + with _lock: + if connection_uuid not in _clients: logger.debug( - "No more telemetry clients, shutting down thread pool executor" + "Creating new TelemetryClient for connection %s", connection_uuid ) - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False + if telemetry_enabled: + _clients[connection_uuid] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, + executor=_executor, + ) + else: + _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + + +def get_telemetry_client(connection_uuid): + """Get the telemetry client for a specific connection""" + try: + if connection_uuid in _clients: + return _clients[connection_uuid] + else: + logger.error( + "Telemetry client not initialized for connection %s", connection_uuid + ) + return NOOP_TELEMETRY_CLIENT + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) + return NOOP_TELEMETRY_CLIENT + + +def _remove_telemetry_client(connection_uuid): + """Remove the telemetry client for a specific connection""" + global _initialized, _executor + with _lock: + if connection_uuid in _clients: + logger.debug("Removing telemetry client for connection %s", connection_uuid) + _clients.pop(connection_uuid, None) + + # Shutdown executor if no more clients + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 35eba8157..975febd20 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -5,8 +5,10 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, - NoopTelemetryClient, - TelemetryClientFactory, + NOOP_TELEMETRY_CLIENT, + initialize_telemetry_client, + get_telemetry_client, + _remove_telemetry_client, ) from databricks.sql.telemetry.models.enums import ( AuthMech, @@ -23,8 +25,8 @@ @pytest.fixture def noop_telemetry_client(): - """Fixture for NoopTelemetryClient.""" - return NoopTelemetryClient() + """Fixture for NOOP_TELEMETRY_CLIENT.""" + return NOOP_TELEMETRY_CLIENT @pytest.fixture @@ -53,30 +55,27 @@ def telemetry_client_setup(): @pytest.fixture -def telemetry_factory_reset(): - """Fixture to reset TelemetryClientFactory state before each test.""" - # Reset the static class state before each test - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False +def telemetry_system_reset(): + """Fixture to reset telemetry system state before each test.""" + # Reset the static state before each test + from databricks.sql.telemetry.telemetry_client import _clients, _executor, _initialized + _clients.clear() + if _executor: + _executor.shutdown(wait=True) + _executor = None + _initialized = False yield # Cleanup after test if needed - TelemetryClientFactory._clients = {} - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False + _clients.clear() + if _executor: + _executor.shutdown(wait=True) + _executor = None + _initialized = False class TestNoopTelemetryClient: - """Tests for the NoopTelemetryClient class.""" - - def test_singleton(self): - """Test that NoopTelemetryClient is a singleton.""" - client1 = NoopTelemetryClient() - client2 = NoopTelemetryClient() - assert client1 is client2 - + """Tests for the NOOP_TELEMETRY_CLIENT.""" + def test_export_initial_telemetry_log(self, noop_telemetry_client): """Test that export_initial_telemetry_log does nothing.""" noop_telemetry_client.export_initial_telemetry_log( @@ -249,8 +248,7 @@ def test_flush(self, telemetry_client_setup): client._send_telemetry.assert_called_once_with(["event1", "event2"]) assert client._events_batch == [] - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory") - def test_close(self, mock_factory_class, telemetry_client_setup): + def test_close(self, telemetry_client_setup): """Test closing the client.""" client = telemetry_client_setup["client"] client._flush = MagicMock() @@ -260,115 +258,63 @@ def test_close(self, mock_factory_class, telemetry_client_setup): client._flush.assert_called_once() -class TestTelemetryClientFactory: - """Tests for the TelemetryClientFactory static class.""" +class TestTelemetrySystem: + """Tests for the telemetry system functions.""" - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_initialize_telemetry_client_enabled(self, mock_client_class, telemetry_factory_reset): + def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is enabled.""" connection_uuid = "test-uuid" auth_provider = MagicMock() host_url = "test-host" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - connection_uuid=connection_uuid, - auth_provider=auth_provider, - host_url=host_url, - ) - # Verify a new client was created and stored - mock_client_class.assert_called_once_with( + initialize_telemetry_client( telemetry_enabled=True, connection_uuid=connection_uuid, auth_provider=auth_provider, host_url=host_url, - executor=TelemetryClientFactory._executor, ) - assert TelemetryClientFactory._clients[connection_uuid] == mock_client - - # Call again with the same connection_uuid - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid=connection_uuid) - # Verify the same client was returned and no new client was created - assert client2 == mock_client - mock_client_class.assert_called_once() # Still only called once + client = get_telemetry_client(connection_uuid) + assert isinstance(client, TelemetryClient) + assert client._connection_uuid == connection_uuid + assert client._auth_provider == auth_provider + assert client._host_url == host_url - def test_initialize_telemetry_client_disabled(self, telemetry_factory_reset): + def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" connection_uuid = "test-uuid" - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=False, connection_uuid=connection_uuid, auth_provider=MagicMock(), host_url="test-host", ) - # Verify a NoopTelemetryClient was stored - assert isinstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient) - - client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid) - assert isinstance(client2, NoopTelemetryClient) - - def test_get_telemetry_client_existing(self, telemetry_factory_reset): - """Test getting an existing telemetry client.""" - connection_uuid = "test-uuid" - mock_client = MagicMock() - TelemetryClientFactory._clients[connection_uuid] = mock_client - - client = TelemetryClientFactory.get_telemetry_client(connection_uuid) - - assert client == mock_client + client = get_telemetry_client(connection_uuid) + assert client is NOOP_TELEMETRY_CLIENT - def test_get_telemetry_client_nonexistent(self, telemetry_factory_reset): + def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" - client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") - - assert isinstance(client, NoopTelemetryClient) + client = get_telemetry_client("nonexistent-uuid") + assert client is NOOP_TELEMETRY_CLIENT - @patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient") - def test_close(self, mock_client_class, mock_executor_class, telemetry_factory_reset): - """Test that factory reinitializes properly after complete shutdown.""" - connection_uuid1 = "test-uuid1" - mock_executor1 = MagicMock() - mock_client1 = MagicMock() - mock_executor_class.return_value = mock_executor1 - mock_client_class.return_value = mock_client1 - - TelemetryClientFactory._clients[connection_uuid1] = mock_client1 - TelemetryClientFactory._executor = mock_executor1 - TelemetryClientFactory._initialized = True - - TelemetryClientFactory.close(connection_uuid1) - - assert TelemetryClientFactory._clients == {} - assert TelemetryClientFactory._executor is None - assert TelemetryClientFactory._initialized is False - mock_executor1.shutdown.assert_called_once_with(wait=True) - - # Now create a new client - this should reinitialize the factory - connection_uuid2 = "test-uuid2" - mock_executor2 = MagicMock() - mock_client2 = MagicMock() - mock_executor_class.return_value = mock_executor2 - mock_client_class.return_value = mock_client2 + def test_close_telemetry_client(self, telemetry_system_reset): + """Test closing a telemetry client.""" + connection_uuid = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" - TelemetryClientFactory.initialize_telemetry_client( + initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid2, - auth_provider=MagicMock(), - host_url="test-host", + connection_uuid=connection_uuid, + auth_provider=auth_provider, + host_url=host_url, ) - # Verify factory was reinitialized - assert TelemetryClientFactory._initialized is True - assert TelemetryClientFactory._executor is not None - assert TelemetryClientFactory._executor == mock_executor2 - assert connection_uuid2 in TelemetryClientFactory._clients - assert TelemetryClientFactory._clients[connection_uuid2] == mock_client2 + client = get_telemetry_client(connection_uuid) + assert isinstance(client, TelemetryClient) + + _remove_telemetry_client(connection_uuid) - # Verify new ThreadPoolExecutor was created - assert mock_executor_class.call_count == 1 \ No newline at end of file + client = get_telemetry_client(connection_uuid) + assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file From 1722a7799ed98b7dacc1cbc31b054a201fb6106b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 12:33:47 +0530 Subject: [PATCH 20/86] renamed connection_uuid as session_id_hex Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 40 +++++------ src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 50 +++++++------- src/databricks/sql/thrift_backend.py | 68 +++++++++---------- tests/unit/test_telemetry.py | 32 ++++----- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index bee60d317..23e4e38b1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -309,13 +309,13 @@ def read(self) -> Optional[OAuthToken]: initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) self._telemetry_client = get_telemetry_client( - connection_uuid=self.get_session_id_hex() + session_id_hex=self.get_session_id_hex() ) driver_connection_params = DriverConnectionParameters( @@ -427,7 +427,7 @@ def cursor( if not self.open: raise InterfaceError( "Cannot create cursor from closed connection", - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), ) cursor = Cursor( @@ -480,7 +480,7 @@ def commit(self): def rollback(self): raise NotSupportedError( "Transactions are not supported on Databricks", - connection_uuid=self.get_session_id_hex(), + session_id_hex=self.get_session_id_hex(), ) @@ -535,7 +535,7 @@ def __iter__(self): else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _determine_parameter_approach( @@ -675,7 +675,7 @@ def _check_not_closed(self): if not self.open: raise InterfaceError( "Attempting operation on closed cursor", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _handle_staging_operation( @@ -695,7 +695,7 @@ def _handle_staging_operation( else: raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) abs_staging_allowed_local_paths = [ @@ -725,7 +725,7 @@ def _handle_staging_operation( if not allow_operation: raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) # May be real headers, or could be json string @@ -756,7 +756,7 @@ def _handle_staging_operation( raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def _handle_staging_put( @@ -770,7 +770,7 @@ def _handle_staging_put( if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "rb") as fh: @@ -789,7 +789,7 @@ def _handle_staging_put( if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) if r.status_code == ACCEPTED: @@ -809,7 +809,7 @@ def _handle_staging_get( if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) r = requests.get(url=presigned_url, headers=headers) @@ -819,7 +819,7 @@ def _handle_staging_get( if not r.ok: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: @@ -835,7 +835,7 @@ def _handle_staging_remove( if not r.ok: raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def execute( @@ -1035,7 +1035,7 @@ def get_async_execution_result(self): else: raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def executemany(self, operation, seq_of_parameters): @@ -1187,7 +1187,7 @@ def fetchall(self) -> List[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchone(self) -> Optional[Row]: @@ -1204,7 +1204,7 @@ def fetchone(self) -> Optional[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchmany(self, size: int) -> List[Row]: @@ -1229,7 +1229,7 @@ def fetchmany(self, size: int) -> List[Row]: else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchall_arrow(self) -> "pyarrow.Table": @@ -1239,7 +1239,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def fetchmany_arrow(self, size) -> "pyarrow.Table": @@ -1249,7 +1249,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", - connection_uuid=self.connection.get_session_id_hex(), + session_id_hex=self.connection.get_session_id_hex(), ) def cancel(self) -> None: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 443d5605f..e7b2dad23 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -14,15 +14,15 @@ class Error(Exception): """ def __init__( - self, message=None, context=None, connection_uuid=None, *args, **kwargs + self, message=None, context=None, session_id_hex=None, *args, **kwargs ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ - if connection_uuid: - telemetry_client = get_telemetry_client(connection_uuid) + if session_id_hex: + telemetry_client = get_telemetry_client(session_id_hex) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 728220789..a918beb09 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -140,15 +140,15 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, telemetry_enabled, - connection_uuid, + session_id_hex, auth_provider, host_url, executor, ): - logger.debug("Initializing TelemetryClient for connection: %s", connection_uuid) + logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = 10 # TODO: Decide on batch size - self._connection_uuid = connection_uuid + self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None self._events_batch = [] @@ -159,7 +159,7 @@ def __init__( def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" - logger.debug("Exporting event for connection %s", self._connection_uuid) + logger.debug("Exporting event for connection %s", self._session_id_hex) with self._lock: self._events_batch.append(event) if len(self._events_batch) >= self._batch_size: @@ -230,7 +230,7 @@ def _telemetry_request_callback(self, future): def export_initial_telemetry_log(self, driver_connection_params, user_agent): logger.debug( - "Exporting initial telemetry log for connection %s", self._connection_uuid + "Exporting initial telemetry log for connection %s", self._session_id_hex ) try: @@ -247,7 +247,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): ), entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, + session_id=self._session_id_hex, system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, ) @@ -260,7 +260,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): logger.debug("Failed to export initial telemetry log: %s", e) def export_failure_log(self, error_name, error_message): - logger.debug("Exporting failure log for connection %s", self._connection_uuid) + logger.debug("Exporting failure log for connection %s", self._session_id_hex) try: error_info = DriverErrorInfo( error_name=error_name, stack_trace=error_message @@ -275,7 +275,7 @@ def export_failure_log(self, error_name, error_message): ), entry=FrontendLogEntry( sql_driver_log=TelemetryEvent( - session_id=self._connection_uuid, + session_id=self._session_id_hex, system_configuration=TelemetryHelper.get_driver_system_configuration(), driver_connection_params=self._driver_connection_params, error_info=error_info, @@ -288,9 +288,9 @@ def export_failure_log(self, error_name, error_message): def close(self): """Flush remaining events before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid) + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - _remove_telemetry_client(self._connection_uuid) + _remove_telemetry_client(self._session_id_hex) # Module-level state @@ -340,41 +340,41 @@ def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): def initialize_telemetry_client( - telemetry_enabled, connection_uuid, auth_provider, host_url + telemetry_enabled, session_id_hex, auth_provider, host_url ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: _initialize() with _lock: - if connection_uuid not in _clients: + if session_id_hex not in _clients: logger.debug( - "Creating new TelemetryClient for connection %s", connection_uuid + "Creating new TelemetryClient for connection %s", session_id_hex ) if telemetry_enabled: - _clients[connection_uuid] = TelemetryClient( + _clients[session_id_hex] = TelemetryClient( telemetry_enabled=telemetry_enabled, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, executor=_executor, ) else: - _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[connection_uuid] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT -def get_telemetry_client(connection_uuid): +def get_telemetry_client(session_id_hex): """Get the telemetry client for a specific connection""" try: - if connection_uuid in _clients: - return _clients[connection_uuid] + if session_id_hex in _clients: + return _clients[session_id_hex] else: logger.error( - "Telemetry client not initialized for connection %s", connection_uuid + "Telemetry client not initialized for connection %s", session_id_hex ) return NOOP_TELEMETRY_CLIENT except Exception as e: @@ -382,13 +382,13 @@ def get_telemetry_client(connection_uuid): return NOOP_TELEMETRY_CLIENT -def _remove_telemetry_client(connection_uuid): +def _remove_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if connection_uuid in _clients: - logger.debug("Removing telemetry client for connection %s", connection_uuid) - _clients.pop(connection_uuid, None) + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + _clients.pop(session_id_hex, None) # Shutdown executor if no more clients if not _clients and _executor: diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7c47da2b1..78683ac31 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -223,7 +223,7 @@ def __init__( raise self._request_lock = threading.RLock() - self._connection_uuid = None + self._session_id_hex = None # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): @@ -256,14 +256,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, connection_uuid=None): + def _check_response_for_error(response, session_id_hex=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) @staticmethod @@ -317,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._connection_uuid, + self._session_id_hex, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -490,7 +490,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response, self._connection_uuid) + ThriftBackend._check_response_for_error(response, self._session_id_hex) return response error_info = response_or_error_info @@ -505,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _check_initial_namespace(self, catalog, schema, response): @@ -519,7 +519,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) if catalog: @@ -527,7 +527,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _check_session_configuration(self, session_configuration): @@ -542,7 +542,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def open_session(self, session_configuration, catalog, schema): @@ -573,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - self._connection_uuid = ( + self._session_id_hex = ( self.handle_to_hex_id(response.sessionHandle) if response.sessionHandle else None @@ -602,7 +602,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) else: raise ServerOperationError( @@ -612,7 +612,7 @@ def _check_command_not_in_error_or_closed_state( and self.guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -623,7 +623,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and self.guid_to_hex_id(op_handle.operationId.guid) }, - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) def _poll_for_status(self, op_handle): @@ -646,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -655,7 +655,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None): + def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -687,7 +687,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) def convert_col(t_column_desc): @@ -698,7 +698,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, connection_uuid=None): + def _col_to_description(col, session_id_hex=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -708,7 +708,7 @@ def _col_to_description(col, connection_uuid=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -722,7 +722,7 @@ def _col_to_description(col, connection_uuid=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, ) else: precision, scale = None, None @@ -730,9 +730,9 @@ def _col_to_description(col, connection_uuid=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema, connection_uuid=None): + def _hive_schema_to_description(t_table_schema, session_id_hex=None): return [ - ThriftBackend._col_to_description(col, connection_uuid) + ThriftBackend._col_to_description(col, session_id_hex) for col in t_table_schema.columns ] @@ -754,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -765,14 +765,14 @@ def _results_message_to_execute_response(self, resp, operation_state): ) description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, - self._connection_uuid, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema, self._session_id_hex ) .serialize() .to_pybytes() @@ -835,14 +835,14 @@ def get_execution_result(self, op_handle, cursor): has_more_rows = resp.hasMoreRows description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, - self._connection_uuid, + self._session_id_hex, ) if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._connection_uuid + t_result_set_metadata_resp.schema, self._session_id_hex ) .serialize() .to_pybytes() @@ -897,27 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState": return operation_state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None): + def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftBackend._check_response_for_error( t_spark_direct_results.operationStatus, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.resultSetMetadata: ThriftBackend._check_response_for_error( t_spark_direct_results.resultSetMetadata, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.resultSet: ThriftBackend._check_response_for_error( t_spark_direct_results.resultSet, - connection_uuid, + session_id_hex, ) if t_spark_direct_results.closeOperation: ThriftBackend._check_response_for_error( t_spark_direct_results.closeOperation, - connection_uuid, + session_id_hex, ) def execute_command( @@ -1066,7 +1066,7 @@ def get_columns( def _handle_execute_response(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1077,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor): def _handle_execute_response_async(self, resp, cursor): cursor.active_op_handle = resp.operationHandle - self._check_direct_results_for_error(resp.directResults, self._connection_uuid) + self._check_direct_results_for_error(resp.directResults, self._session_id_hex) def fetch_results( self, @@ -1112,7 +1112,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - connection_uuid=self._connection_uuid, + session_id_hex=self._session_id_hex, ) queue = ResultSetQueueFactory.build_queue( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 975febd20..f89191ca4 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -32,14 +32,14 @@ def noop_telemetry_client(): @pytest.fixture def telemetry_client_setup(): """Fixture for TelemetryClient setup data.""" - connection_uuid = str(uuid.uuid4()) + session_id_hex = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") host_url = "test-host" executor = MagicMock() client = TelemetryClient( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, executor=executor, @@ -47,7 +47,7 @@ def telemetry_client_setup(): return { "client": client, - "connection_uuid": connection_uuid, + "session_id_hex": session_id_hex, "auth_provider": auth_provider, "host_url": host_url, "executor": executor, @@ -217,7 +217,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup) unauthenticated_client = TelemetryClient( telemetry_enabled=True, - connection_uuid=str(uuid.uuid4()), + session_id_hex=str(uuid.uuid4()), auth_provider=None, # No auth provider host_url=host_url, executor=executor, @@ -263,34 +263,34 @@ class TestTelemetrySystem: def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is enabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" auth_provider = MagicMock() host_url = "test-host" initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - assert client._connection_uuid == connection_uuid + assert client._session_id_hex == session_id_hex assert client._auth_provider == auth_provider assert client._host_url == host_url def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" initialize_telemetry_client( telemetry_enabled=False, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): @@ -300,21 +300,21 @@ def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): def test_close_telemetry_client(self, telemetry_system_reset): """Test closing a telemetry client.""" - connection_uuid = "test-uuid" + session_id_hex = "test-uuid" auth_provider = MagicMock() host_url = "test-host" initialize_telemetry_client( telemetry_enabled=True, - connection_uuid=connection_uuid, + session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - _remove_telemetry_client(connection_uuid) + _remove_telemetry_client(session_id_hex) - client = get_telemetry_client(connection_uuid) + client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file From e84143419a4ab67e88fcf317145fb74d0ed12a46 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 15:43:52 +0530 Subject: [PATCH 21/86] added NotImplementedError to abstract class, added unit tests Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 6 +- tests/unit/test_telemetry.py | 156 +++++++++++++++++- 2 files changed, 158 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index a918beb09..585945cc6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,15 +104,15 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass + raise NotImplementedError("Subclasses must implement export_initial_telemetry_log") @abstractmethod def export_failure_log(self, error_name, error_message): - pass + raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod def close(self): - pass + raise NotImplementedError("Subclasses must implement close") # A single instance of the no-op client that can be reused diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index f89191ca4..b82780f4b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -9,10 +9,13 @@ initialize_telemetry_client, get_telemetry_client, _remove_telemetry_client, + TelemetryHelper, + BaseTelemetryClient ) from databricks.sql.telemetry.models.enums import ( AuthMech, DatabricksClientType, + AuthFlow, ) from databricks.sql.telemetry.models.event import ( DriverConnectionParameters, @@ -20,6 +23,8 @@ ) from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, ) @@ -257,6 +262,72 @@ def test_close(self, telemetry_client_setup): client._flush.assert_called_once() + @patch("requests.post") + def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup): + """Test successful telemetry request callback.""" + client = telemetry_client_setup["client"] + + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_future = MagicMock() + mock_future.result.return_value = mock_response + + client._telemetry_request_callback(mock_future) + + mock_future.result.assert_called_once() + + @patch("requests.post") + def test_telemetry_request_callback_failure(self, mock_post, telemetry_client_setup): + """Test telemetry request callback with failure""" + client = telemetry_client_setup["client"] + + # Test with non-200 status code + mock_response = MagicMock() + mock_response.status_code = 500 + future = MagicMock() + future.result.return_value = mock_response + client._telemetry_request_callback(future) + + # Test with exception + future = MagicMock() + future.result.side_effect = Exception("Test error") + client._telemetry_request_callback(future) + + def test_telemetry_client_exception_handling(self, telemetry_client_setup): + """Test exception handling in telemetry client methods.""" + client = telemetry_client_setup["client"] + + # Test export_initial_telemetry_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_initial_telemetry_log(MagicMock(), "test-agent") + + # Test export_failure_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_failure_log("TestError", "Test error message") + + # Test _send_telemetry with exception + with patch.object(client._executor, 'submit', side_effect=Exception("Test error")): + # Should not raise exception + client._send_telemetry([MagicMock()]) + + def test_send_telemetry_thread_pool_failure(self, telemetry_client_setup): + """Test handling of thread pool submission failure""" + client = telemetry_client_setup["client"] + client._executor.submit.side_effect = Exception("Thread pool error") + event = MagicMock() + client._send_telemetry([event]) + + def test_base_telemetry_client_abstract_methods(self): + """Test that BaseTelemetryClient cannot be instantiated without implementing all abstract methods""" + class TestBaseClient(BaseTelemetryClient): + pass + + with pytest.raises(TypeError): + TestBaseClient() # Can't instantiate abstract class + class TestTelemetrySystem: """Tests for the telemetry system functions.""" @@ -317,4 +388,87 @@ def test_close_telemetry_client(self, telemetry_system_reset): _remove_telemetry_client(session_id_hex) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT \ No newline at end of file + assert client is NOOP_TELEMETRY_CLIENT + + @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") + def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): + """Test that global exception hook is installed and handles exceptions.""" + from databricks.sql.telemetry.telemetry_client import _install_exception_hook, _handle_unhandled_exception + + _install_exception_hook() + + test_exception = ValueError("Test exception") + _handle_unhandled_exception(type(test_exception), test_exception, None) + + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) + + +class TestTelemetryHelper: + """Tests for the TelemetryHelper class.""" + + def test_get_driver_system_configuration(self): + """Test getting driver system configuration.""" + config = TelemetryHelper.get_driver_system_configuration() + + assert isinstance(config.driver_name, str) + assert isinstance(config.driver_version, str) + assert isinstance(config.runtime_name, str) + assert isinstance(config.runtime_vendor, str) + assert isinstance(config.runtime_version, str) + assert isinstance(config.os_name, str) + assert isinstance(config.os_version, str) + assert isinstance(config.os_arch, str) + assert isinstance(config.locale_name, str) + assert isinstance(config.char_set_encoding, str) + + assert config.driver_name == "Databricks SQL Python Connector" + assert "Python" in config.runtime_name + assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] + assert config.os_name in ["Darwin", "Linux", "Windows"] + + # Verify caching behavior + config2 = TelemetryHelper.get_driver_system_configuration() + assert config is config2 # Should return same instance + + def test_get_auth_mechanism(self): + """Test getting auth mechanism for different auth providers.""" + # Test PAT auth + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT + + # Test OAuth auth + oauth_auth = MagicMock(spec=DatabricksOAuthProvider) + assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH + + # Test External auth + external_auth = MagicMock(spec=ExternalAuthProvider) + assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH + + # Test None auth provider + assert TelemetryHelper.get_auth_mechanism(None) is None + + # Test unknown auth provider + unknown_auth = MagicMock() + assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT + + def test_get_auth_flow(self): + """Test getting auth flow for different OAuth providers.""" + # Test OAuth with existing tokens + oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_tokens._access_token = "test-access-token" + oauth_with_tokens._refresh_token = "test-refresh-token" + assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH + + # Test OAuth with browser-based auth + oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_browser._access_token = None + oauth_with_browser._refresh_token = None + oauth_with_browser.oauth_manager = MagicMock() + assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + # Test non-OAuth provider + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_flow(pat_auth) is None + + # Test None auth provider + assert TelemetryHelper.get_auth_flow(None) is None \ No newline at end of file From 2f89266cd2d44745fcf4a528aa07b8b2050299a8 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 16 Jun 2025 15:45:42 +0530 Subject: [PATCH 22/86] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 585945cc6..a918beb09 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,15 +104,15 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - raise NotImplementedError("Subclasses must implement export_initial_telemetry_log") + pass @abstractmethod def export_failure_log(self, error_name, error_message): - raise NotImplementedError("Subclasses must implement export_failure_log") + pass @abstractmethod def close(self): - raise NotImplementedError("Subclasses must implement close") + pass # A single instance of the no-op client that can be reused From 5564bbb9a0c40123caf9c13058fbc7b030bea44a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 10:31:00 +0530 Subject: [PATCH 23/86] added PEP-249 link, changed NoopTelemetryClient implementation Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 1 + .../sql/telemetry/telemetry_client.py | 32 ++++++++++++------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index e7b2dad23..20a898999 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -7,6 +7,7 @@ logger = logging.getLogger(__name__) ### PEP-249 Mandated ### +# https://peps.python.org/pep-0249/#exceptions class Error(Exception): """Base class for DB-API2.0 exceptions. `message`: An optional user-friendly error message. It should be short, actionable and stable diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index a918beb09..06a6813c1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -104,27 +104,37 @@ class BaseTelemetryClient(ABC): @abstractmethod def export_initial_telemetry_log(self, driver_connection_params, user_agent): - pass + raise NotImplementedError( + "Subclasses must implement export_initial_telemetry_log" + ) @abstractmethod def export_failure_log(self, error_name, error_message): - pass + raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod + def close(self): + raise NotImplementedError("Subclasses must implement close") + + +class NoopTelemetryClient(BaseTelemetryClient): + """ + NoopTelemetryClient is a telemetry client that does not send any events to the server. + It is used when telemetry is disabled. + """ + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + pass + + def export_failure_log(self, error_name, error_message): + pass + def close(self): pass # A single instance of the no-op client that can be reused -NOOP_TELEMETRY_CLIENT = type( - "NoopTelemetryClient", - (BaseTelemetryClient,), - { - "export_initial_telemetry_log": lambda self, *args, **kwargs: None, - "export_failure_log": lambda self, *args, **kwargs: None, - "close": lambda self: None, - }, -)() +NOOP_TELEMETRY_CLIENT = NoopTelemetryClient() class TelemetryClient(BaseTelemetryClient): From 1e4e8cfb07dc2ecef5b73ba36bde487b4aaec965 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 12:10:28 +0530 Subject: [PATCH 24/86] removed unused import Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 20a898999..9ca662126 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,6 +1,5 @@ import json import logging -import traceback from databricks.sql.telemetry.telemetry_client import get_telemetry_client From 55b29bceafcc01541e77f9d69264d214656c60e2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 14:45:43 +0530 Subject: [PATCH 25/86] made telemetry client close a module-level function Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 3 +- .../sql/telemetry/telemetry_client.py | 6 ++-- tests/unit/test_telemetry.py | 30 +++++++++++++++++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 23e4e38b1..9359c4272 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -56,6 +56,7 @@ TelemetryHelper, initialize_telemetry_client, get_telemetry_client, + close_telemetry_client, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -471,7 +472,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - self._telemetry_client.close() + close_telemetry_client(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 06a6813c1..1fd850dbc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -300,7 +300,6 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - _remove_telemetry_client(self._session_id_hex) # Module-level state @@ -392,13 +391,14 @@ def get_telemetry_client(session_id_hex): return NOOP_TELEMETRY_CLIENT -def _remove_telemetry_client(session_id_hex): +def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: if session_id_hex in _clients: logger.debug("Removing telemetry client for connection %s", session_id_hex) - _clients.pop(session_id_hex, None) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() # Shutdown executor if no more clients if not _clients and _executor: diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b82780f4b..217f4cfaa 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -8,7 +8,7 @@ NOOP_TELEMETRY_CLIENT, initialize_telemetry_client, get_telemetry_client, - _remove_telemetry_client, + close_telemetry_client, TelemetryHelper, BaseTelemetryClient ) @@ -385,7 +385,33 @@ def test_close_telemetry_client(self, telemetry_system_reset): client = get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) - _remove_telemetry_client(session_id_hex) + client.close = MagicMock() + + close_telemetry_client(session_id_hex) + + client.close.assert_called_once() + + client = get_telemetry_client(session_id_hex) + assert client is NOOP_TELEMETRY_CLIENT + + def test_close_telemetry_client_noop(self, telemetry_system_reset): + """Test closing a no-op telemetry client.""" + session_id_hex = "test-uuid" + initialize_telemetry_client( + telemetry_enabled=False, + session_id_hex=session_id_hex, + auth_provider=MagicMock(), + host_url="test-host", + ) + + client = get_telemetry_client(session_id_hex) + assert client is NOOP_TELEMETRY_CLIENT + + client.close = MagicMock() + + close_telemetry_client(session_id_hex) + + client.close.assert_called_once() client = get_telemetry_client(session_id_hex) assert client is NOOP_TELEMETRY_CLIENT From 93bf170f46cda84f8658ef4149ed84328b14778e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:15:47 +0530 Subject: [PATCH 26/86] unit tests verbose Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 462d22369..265f8a829 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit -v run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit -v check-linting: runs-on: ubuntu-latest strategy: From 45f5ccf0e97196b1232397756e1faecf67e98a9c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:34:55 +0530 Subject: [PATCH 27/86] debug logs in unit tests Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 265f8a829..d78854671 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v + run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v + run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v check-linting: runs-on: ubuntu-latest strategy: From 8ff1c1fa26dc3b0d62419e21faf2ffa03d136f9e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 15:57:42 +0530 Subject: [PATCH 28/86] debug logs in unit tests Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index d78854671..7a221c53a 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v + run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: LOG_LEVEL=DEBUG poetry run python -m pytest tests/unit -v + run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG check-linting: runs-on: ubuntu-latest strategy: From 8bdd3243bb4394a7aa9a1ea930f27142b30b6d1e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 18:40:34 +0530 Subject: [PATCH 29/86] removed ABC from mixin, added try/catch block around executor shutdown Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 4 +- .../sql/telemetry/telemetry_client.py | 36 +++++++++------- src/databricks/sql/telemetry/utils.py | 3 +- tests/unit/test_client.py | 41 ++++++++++--------- tests/unit/test_telemetry.py | 24 +++++++---- 5 files changed, 61 insertions(+), 47 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 7a221c53a..e23072d3e 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG + run: poetry run python -m pytest tests/unit run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s --log-cli-level=DEBUG + run: poetry run python -m pytest tests/unit -v -s check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 1fd850dbc..7d6f7b404 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -123,6 +123,13 @@ class NoopTelemetryClient(BaseTelemetryClient): It is used when telemetry is disabled. """ + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NoopTelemetryClient, cls).__new__(cls) + return cls._instance + def export_initial_telemetry_log(self, driver_connection_params, user_agent): pass @@ -133,10 +140,6 @@ def close(self): pass -# A single instance of the no-op client that can be reused -NOOP_TELEMETRY_CLIENT = NoopTelemetryClient() - - class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. @@ -369,11 +372,11 @@ def initialize_telemetry_client( executor=_executor, ) else: - _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NoopTelemetryClient() except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[session_id_hex] = NOOP_TELEMETRY_CLIENT + _clients[session_id_hex] = NoopTelemetryClient() def get_telemetry_client(session_id_hex): @@ -385,10 +388,10 @@ def get_telemetry_client(session_id_hex): logger.error( "Telemetry client not initialized for connection %s", session_id_hex ) - return NOOP_TELEMETRY_CLIENT + return NoopTelemetryClient() except Exception as e: logger.debug("Failed to get telemetry client: %s", e) - return NOOP_TELEMETRY_CLIENT + return NoopTelemetryClient() def close_telemetry_client(session_id_hex): @@ -401,10 +404,13 @@ def close_telemetry_client(session_id_hex): telemetry_client.close() # Shutdown executor if no more clients - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) diff --git a/src/databricks/sql/telemetry/utils.py b/src/databricks/sql/telemetry/utils.py index 6d95526b8..df7acf28c 100644 --- a/src/databricks/sql/telemetry/utils.py +++ b/src/databricks/sql/telemetry/utils.py @@ -1,10 +1,9 @@ import json from enum import Enum from dataclasses import asdict, is_dataclass -from abc import ABC -class JsonSerializableMixin(ABC): +class JsonSerializableMixin: """Mixin class to provide JSON serialization capabilities to dataclasses.""" def to_json(self) -> str: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..427a7d7bd 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -337,6 +337,7 @@ def test_negative_fetch_throws_exception(self): result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): + print("hellow") mock_close = Mock() with client.Cursor(Mock(), Mock()) as cursor: cursor.close = mock_close @@ -351,29 +352,30 @@ def test_context_manager_closes_cursor(self): finally: cursor.close.assert_called() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value + # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + # def test_context_manager_closes_connection(self, mock_client_class): + # print("hellow1") + # instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + # mock_open_session_resp.sessionHandle.sessionId = b"\x22" + # instance.open_session.return_value = mock_open_session_resp - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass + # with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + # pass - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # # Check the close session request has an id of x22 + # close_session_id = instance.close_session.call_args[0][0].sessionId + # self.assertEqual(close_session_id, b"\x22") - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close = Mock() - try: - with self.assertRaises(KeyboardInterrupt): - with connection: - raise KeyboardInterrupt("Simulated interrupt") - finally: - connection.close.assert_called() + # connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + # connection.close = Mock() + # try: + # with self.assertRaises(KeyboardInterrupt): + # with connection: + # raise KeyboardInterrupt("Simulated interrupt") + # finally: + # connection.close.assert_called() def dict_product(self, dicts): """ @@ -791,6 +793,7 @@ def test_cursor_context_manager_handles_exit_exception(self): def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" + print("banana") cursors_closed = [] def mock_close_with_exception(): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 217f4cfaa..84833dafd 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -5,7 +5,7 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, - NOOP_TELEMETRY_CLIENT, + NoopTelemetryClient, initialize_telemetry_client, get_telemetry_client, close_telemetry_client, @@ -30,8 +30,8 @@ @pytest.fixture def noop_telemetry_client(): - """Fixture for NOOP_TELEMETRY_CLIENT.""" - return NOOP_TELEMETRY_CLIENT + """Fixture for NoopTelemetryClient.""" + return NoopTelemetryClient() @pytest.fixture @@ -79,7 +79,13 @@ def telemetry_system_reset(): class TestNoopTelemetryClient: - """Tests for the NOOP_TELEMETRY_CLIENT.""" + """Tests for the NoopTelemetryClient.""" + + def test_singleton(self): + """Test that NoopTelemetryClient is a singleton.""" + client1 = NoopTelemetryClient() + client2 = NoopTelemetryClient() + assert client1 is client2 def test_export_initial_telemetry_log(self, noop_telemetry_client): """Test that export_initial_telemetry_log does nothing.""" @@ -362,12 +368,12 @@ def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): ) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" client = get_telemetry_client("nonexistent-uuid") - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client(self, telemetry_system_reset): """Test closing a telemetry client.""" @@ -392,7 +398,7 @@ def test_close_telemetry_client(self, telemetry_system_reset): client.close.assert_called_once() client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client_noop(self, telemetry_system_reset): """Test closing a no-op telemetry client.""" @@ -405,7 +411,7 @@ def test_close_telemetry_client_noop(self, telemetry_system_reset): ) client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) client.close = MagicMock() @@ -414,7 +420,7 @@ def test_close_telemetry_client_noop(self, telemetry_system_reset): client.close.assert_called_once() client = get_telemetry_client(session_id_hex) - assert client is NOOP_TELEMETRY_CLIENT + assert isinstance(client, NoopTelemetryClient) @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): From f99f7ea98f1385c07855f0e139e023da049a1cc8 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:30:53 +0530 Subject: [PATCH 30/86] checking stuff Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- tests/unit/test_client.py | 49 +++++++++++------------ 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index e23072d3e..f5b23dd28 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -75,7 +75,7 @@ jobs: uses: actions/checkout@v2 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 427a7d7bd..d64b06b5f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -52,7 +52,6 @@ def new(cls): ) ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp - return ThriftBackendMock @classmethod @@ -352,30 +351,30 @@ def test_context_manager_closes_cursor(self): finally: cursor.close.assert_called() - # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - # def test_context_manager_closes_connection(self, mock_client_class): - # print("hellow1") - # instance = mock_client_class.return_value - - # mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - # mock_open_session_resp.sessionHandle.sessionId = b"\x22" - # instance.open_session.return_value = mock_open_session_resp - - # with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - # pass - - # # Check the close session request has an id of x22 - # close_session_id = instance.close_session.call_args[0][0].sessionId - # self.assertEqual(close_session_id, b"\x22") - - # connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - # connection.close = Mock() - # try: - # with self.assertRaises(KeyboardInterrupt): - # with connection: - # raise KeyboardInterrupt("Simulated interrupt") - # finally: - # connection.close.assert_called() + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + print("hellow1") + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close = Mock() + try: + with self.assertRaises(KeyboardInterrupt): + with connection: + raise KeyboardInterrupt("Simulated interrupt") + finally: + connection.close.assert_called() def dict_product(self, dicts): """ From b972c8a36e29414d832ee5806d4ccedf9e98dce3 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:35:25 +0530 Subject: [PATCH 31/86] finding out --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index f5b23dd28..8f8a2278a 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t", "3.14"] steps: #---------------------------------------------- # check-out repo and set-up python From 7ca36363e3d5c324287def891f9a261189d4931a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:37:20 +0530 Subject: [PATCH 32/86] finding out more --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 8f8a2278a..158ac64a6 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t", "3.14"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t"] steps: #---------------------------------------------- # check-out repo and set-up python From 0ac8ed2d7ad3904c5a1312a4489a7b913ab35a74 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:49:31 +0530 Subject: [PATCH 33/86] more more finding out more nice --- src/databricks/sql/telemetry/telemetry_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 7d6f7b404..9d1a38909 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -183,9 +183,9 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) From c457a0970d99604347d40478a7f1ef3f3e53674d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 20:53:19 +0530 Subject: [PATCH 34/86] locks are useless anyways --- .../sql/telemetry/telemetry_client.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9d1a38909..9176cdcaf 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -183,9 +183,9 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - # with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) @@ -397,20 +397,20 @@ def get_telemetry_client(session_id_hex): def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor - with _lock: - if session_id_hex in _clients: - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) - telemetry_client.close() - # Shutdown executor if no more clients - try: - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() + + # Shutdown executor if no more clients + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) From 5f07a84bcc54aa28c117a8fe5771a5302a019ea5 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:00:51 +0530 Subject: [PATCH 35/86] haha --- .../sql/telemetry/telemetry_client.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9176cdcaf..754772235 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -317,15 +317,14 @@ def close(self): def _initialize(): """Initialize the telemetry system if not already initialized""" global _initialized, _executor - with _lock: - if not _initialized: - _clients.clear() - _executor = ThreadPoolExecutor(max_workers=10) - _install_exception_hook() - _initialized = True - logger.debug( - "Telemetry system initialized with thread pool (max_workers=10)" - ) + if not _initialized: + _clients.clear() + _executor = ThreadPoolExecutor(max_workers=10) + _install_exception_hook() + _initialized = True + logger.debug( + "Telemetry system initialized with thread pool (max_workers=10)" + ) def _install_exception_hook(): @@ -356,9 +355,8 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - _initialize() - with _lock: + _initialize() if session_id_hex not in _clients: logger.debug( "Creating new TelemetryClient for connection %s", session_id_hex @@ -371,8 +369,10 @@ def initialize_telemetry_client( host_url=host_url, executor=_executor, ) + print("i have initialized the telemetry client yes") else: _clients[session_id_hex] = NoopTelemetryClient() + print("i have initialized the noop client yes") except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail @@ -397,20 +397,20 @@ def get_telemetry_client(session_id_hex): def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor + with _lock: + if session_id_hex in _clients: + logger.debug("Removing telemetry client for connection %s", session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + telemetry_client.close() - if session_id_hex in _clients: - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) - telemetry_client.close() - - # Shutdown executor if no more clients - try: - if not _clients and _executor: - logger.debug( - "No more telemetry clients, shutting down thread pool executor" - ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + # Shutdown executor if no more clients + try: + if not _clients and _executor: + logger.debug( + "No more telemetry clients, shutting down thread pool executor" + ) + _executor.shutdown(wait=True) + _executor = None + _initialized = False + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) From 1115e2523f1419016aa41b2a710867eb488bacf7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:07:03 +0530 Subject: [PATCH 36/86] normal --- tests/unit/test_client.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d64b06b5f..981543552 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -497,17 +497,17 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) + # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + # def test_configuration_passthrough(self, mock_client_class): + # mock_session_config = Mock() + # databricks.sql.connect( + # session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + # ) + + # self.assertEqual( + # mock_client_class.return_value.open_session.call_args[0][0], + # mock_session_config, + # ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): From de1ed87b6b4b17e9cadf0ac9b961bb4c2f730704 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:17:32 +0530 Subject: [PATCH 37/86] := looks like walrus horizontally --- src/databricks/sql/telemetry/telemetry_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 754772235..c194f9f90 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,9 +398,8 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if session_id_hex in _clients: + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client = _clients.pop(session_id_hex, None) telemetry_client.close() # Shutdown executor if no more clients From 554aeaf02ef36628797b99370464b15c31c56144 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:35:16 +0530 Subject: [PATCH 38/86] one more --- .../sql/telemetry/telemetry_client.py | 4 +++- tests/unit/test_client.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c194f9f90..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,7 +398,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 981543552..d64b06b5f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -497,17 +497,17 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - # @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - # def test_configuration_passthrough(self, mock_client_class): - # mock_session_config = Mock() - # databricks.sql.connect( - # session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - # ) - - # self.assertEqual( - # mock_client_class.return_value.open_session.call_args[0][0], - # mock_session_config, - # ) + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): From fffac5f70ebfecfcb55fc330cc6ebeab67a2c5e0 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:38:48 +0530 Subject: [PATCH 39/86] walrus again --- src/databricks/sql/telemetry/telemetry_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..c194f9f90 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,9 +398,7 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From b77208a8231fccec5b03590dc1d5cb6fbcfb5418 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:42:46 +0530 Subject: [PATCH 40/86] old stuff without walrus seems to fail --- src/databricks/sql/telemetry/telemetry_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c194f9f90..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -398,7 +398,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 733c288e36854398a03baa57060493ac1cb6474a Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 21:44:45 +0530 Subject: [PATCH 41/86] manually do the walrussing --- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..af32de489 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -399,8 +399,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + telemetry_client = _clients.pop(session_id_hex, None) + if telemetry_client is not None: logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From ca8b9586e360e6a70de88657a76f952f0d3cdd71 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:36:57 +0530 Subject: [PATCH 42/86] change 3.13t, v2 Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 6 +++--- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 158ac64a6..df6a0e169 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -61,7 +61,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit run-unit-tests-with-arrow: runs-on: ubuntu-latest strategy: @@ -75,7 +75,7 @@ jobs: uses: actions/checkout@v2 - name: Set up python ${{ matrix.python-version }} id: setup-python - uses: actions/setup-python@v5 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} #---------------------------------------------- @@ -117,7 +117,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13t"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: #---------------------------------------------- # check-out repo and set-up python diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index af32de489..e36f05ea7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -399,8 +399,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - telemetry_client = _clients.pop(session_id_hex, None) - if telemetry_client is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 3eabac9641b893927ca417e2ad2d8007914c4455 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:45:24 +0530 Subject: [PATCH 43/86] formatting, added walrus Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e36f05ea7..5ff1a63d9 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -322,9 +322,7 @@ def _initialize(): _executor = ThreadPoolExecutor(max_workers=10) _install_exception_hook() _initialized = True - logger.debug( - "Telemetry system initialized with thread pool (max_workers=10)" - ) + logger.debug("Telemetry system initialized with thread pool (max_workers=10)") def _install_exception_hook(): @@ -398,9 +396,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) + if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + # if session_id_hex in _clients: + # telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From fb9ef43b3857625d73977ff98c70992cbc9dcc12 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:55:49 +0530 Subject: [PATCH 44/86] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5ff1a63d9..c10dd4083 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -397,8 +397,8 @@ def close_telemetry_client(session_id_hex): global _initialized, _executor with _lock: if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - # if session_id_hex in _clients: - # telemetry_client = _clients.pop(session_id_hex) + # if session_id_hex in _clients: + # telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() From 1e795aa4f7a5bb2e12e79748a1f76fc368bc8645 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 17 Jun 2025 23:59:53 +0530 Subject: [PATCH 45/86] removed walrus, removed test before stalling test Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 6 +-- tests/unit/test_client.py | 54 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c10dd4083..c7daff14f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -396,9 +396,9 @@ def close_telemetry_client(session_id_hex): """Remove the telemetry client for a specific connection""" global _initialized, _executor with _lock: - if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - # if session_id_hex in _clients: - # telemetry_client = _clients.pop(session_id_hex) + # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: + if session_id_hex in _clients: + telemetry_client = _clients.pop(session_id_hex) logger.debug("Removing telemetry client for connection %s", session_id_hex) telemetry_client.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d64b06b5f..cc41e6c87 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -790,41 +790,41 @@ def test_cursor_context_manager_handles_exit_exception(self): cursor.close.assert_called_once() - def test_connection_close_handles_cursor_close_exception(self): - """Test that _close handles exceptions from cursor.close() properly.""" - print("banana") - cursors_closed = [] + # def test_connection_close_handles_cursor_close_exception(self): + # """Test that _close handles exceptions from cursor.close() properly.""" + # print("banana") + # cursors_closed = [] - def mock_close_with_exception(): - cursors_closed.append(1) - raise Exception("Test error during close") + # def mock_close_with_exception(): + # cursors_closed.append(1) + # raise Exception("Test error during close") - cursor1 = Mock() - cursor1.close = mock_close_with_exception + # cursor1 = Mock() + # cursor1.close = mock_close_with_exception - def mock_close_normal(): - cursors_closed.append(2) + # def mock_close_normal(): + # cursors_closed.append(2) - cursor2 = Mock() - cursor2.close = mock_close_normal + # cursor2 = Mock() + # cursor2.close = mock_close_normal - mock_backend = Mock() - mock_session_handle = Mock() + # mock_backend = Mock() + # mock_session_handle = Mock() - try: - for cursor in [cursor1, cursor2]: - try: - cursor.close() - except Exception: - pass + # try: + # for cursor in [cursor1, cursor2]: + # try: + # cursor.close() + # except Exception: + # pass - mock_backend.close_session(mock_session_handle) - except Exception as e: - self.fail(f"Connection close should handle exceptions: {e}") + # mock_backend.close_session(mock_session_handle) + # except Exception as e: + # self.fail(f"Connection close should handle exceptions: {e}") - self.assertEqual( - cursors_closed, [1, 2], "Both cursors should have close called" - ) + # self.assertEqual( + # cursors_closed, [1, 2], "Both cursors should have close called" + # ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" From 2c293a5f66294bd0c5ded64825ec774eba000763 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 09:27:03 +0530 Subject: [PATCH 46/86] changed order of stalling test Signed-off-by: Sai Shree Pradhan --- tests/unit/test_client.py | 56 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index cc41e6c87..d3a0c9866 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -352,7 +352,7 @@ def test_context_manager_closes_cursor(self): cursor.close.assert_called() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): + def test_a_context_manager_closes_connection(self, mock_client_class): print("hellow1") instance = mock_client_class.return_value @@ -790,41 +790,41 @@ def test_cursor_context_manager_handles_exit_exception(self): cursor.close.assert_called_once() - # def test_connection_close_handles_cursor_close_exception(self): - # """Test that _close handles exceptions from cursor.close() properly.""" - # print("banana") - # cursors_closed = [] + def test_connection_close_handles_cursor_close_exception(self): + """Test that _close handles exceptions from cursor.close() properly.""" + print("banana") + cursors_closed = [] - # def mock_close_with_exception(): - # cursors_closed.append(1) - # raise Exception("Test error during close") + def mock_close_with_exception(): + cursors_closed.append(1) + raise Exception("Test error during close") - # cursor1 = Mock() - # cursor1.close = mock_close_with_exception + cursor1 = Mock() + cursor1.close = mock_close_with_exception - # def mock_close_normal(): - # cursors_closed.append(2) + def mock_close_normal(): + cursors_closed.append(2) - # cursor2 = Mock() - # cursor2.close = mock_close_normal + cursor2 = Mock() + cursor2.close = mock_close_normal - # mock_backend = Mock() - # mock_session_handle = Mock() + mock_backend = Mock() + mock_session_handle = Mock() - # try: - # for cursor in [cursor1, cursor2]: - # try: - # cursor.close() - # except Exception: - # pass + try: + for cursor in [cursor1, cursor2]: + try: + cursor.close() + except Exception: + pass - # mock_backend.close_session(mock_session_handle) - # except Exception as e: - # self.fail(f"Connection close should handle exceptions: {e}") + mock_backend.close_session(mock_session_handle) + except Exception as e: + self.fail(f"Connection close should handle exceptions: {e}") - # self.assertEqual( - # cursors_closed, [1, 2], "Both cursors should have close called" - # ) + self.assertEqual( + cursors_closed, [1, 2], "Both cursors should have close called" + ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" From d237255818ef9246e417430d86be07475cf0a786 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 10:40:40 +0530 Subject: [PATCH 47/86] removed debugging, added TelemetryClientFactory Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 10 +- src/databricks/sql/exc.py | 6 +- .../sql/telemetry/telemetry_client.py | 226 ++++++++++-------- tests/unit/test_client.py | 5 +- tests/unit/test_telemetry.py | 203 ++++++++-------- 5 files changed, 231 insertions(+), 219 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9359c4272..26705f3f8 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -54,9 +54,7 @@ ) from databricks.sql.telemetry.telemetry_client import ( TelemetryHelper, - initialize_telemetry_client, - get_telemetry_client, - close_telemetry_client, + TelemetryClientFactory, ) from databricks.sql.telemetry.models.enums import DatabricksClientType from databricks.sql.telemetry.models.event import ( @@ -308,14 +306,14 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=self.telemetry_enabled, session_id_hex=self.get_session_id_hex(), auth_provider=auth_provider, host_url=self.host, ) - self._telemetry_client = get_telemetry_client( + self._telemetry_client = TelemetryClientFactory.get_telemetry_client( session_id_hex=self.get_session_id_hex() ) @@ -472,7 +470,7 @@ def _close(self, close_cursors=True) -> None: self.open = False - close_telemetry_client(self.get_session_id_hex()) + TelemetryClientFactory.close(self.get_session_id_hex()) def commit(self): """No-op because Databricks does not support transactions""" diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 9ca662126..30fd6c26d 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,7 +1,7 @@ import json import logging -from databricks.sql.telemetry.telemetry_client import get_telemetry_client +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory logger = logging.getLogger(__name__) @@ -22,7 +22,9 @@ def __init__( error_name = self.__class__.__name__ if session_id_hex: - telemetry_client = get_telemetry_client(session_id_hex) + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) telemetry_client.export_failure_log(error_name, self.message) def __str__(self): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c7daff14f..f7fccf47a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -305,111 +305,131 @@ def close(self): self._flush() -# Module-level state -_clients: Dict[str, BaseTelemetryClient] = {} -_executor: Optional[ThreadPoolExecutor] = None -_initialized: bool = False -_lock = threading.Lock() -_original_excepthook = None -_excepthook_installed = False - - -def _initialize(): - """Initialize the telemetry system if not already initialized""" - global _initialized, _executor - if not _initialized: - _clients.clear() - _executor = ThreadPoolExecutor(max_workers=10) - _install_exception_hook() - _initialized = True - logger.debug("Telemetry system initialized with thread pool (max_workers=10)") - - -def _install_exception_hook(): - """Install global exception handler for unhandled exceptions""" - global _excepthook_installed, _original_excepthook - if not _excepthook_installed: - _original_excepthook = sys.excepthook - sys.excepthook = _handle_unhandled_exception - _excepthook_installed = True - logger.debug("Global exception handler installed for telemetry") - - -def _handle_unhandled_exception(exc_type, exc_value, exc_traceback): - """Handle unhandled exceptions by sending telemetry and flushing thread pool""" - logger.debug("Handling unhandled exception: %s", exc_type.__name__) - - clients_to_close = list(_clients.values()) - for client in clients_to_close: - client.close() - - # Call the original exception handler to maintain normal behavior - if _original_excepthook: - _original_excepthook(exc_type, exc_value, exc_traceback) - - -def initialize_telemetry_client( - telemetry_enabled, session_id_hex, auth_provider, host_url -): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" - try: - with _lock: - _initialize() - if session_id_hex not in _clients: - logger.debug( - "Creating new TelemetryClient for connection %s", session_id_hex - ) - if telemetry_enabled: - _clients[session_id_hex] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=_executor, - ) - print("i have initialized the telemetry client yes") - else: - _clients[session_id_hex] = NoopTelemetryClient() - print("i have initialized the noop client yes") - except Exception as e: - logger.debug("Failed to initialize telemetry client: %s", e) - # Fallback to NoopTelemetryClient to ensure connection doesn't fail - _clients[session_id_hex] = NoopTelemetryClient() - - -def get_telemetry_client(session_id_hex): - """Get the telemetry client for a specific connection""" - try: - if session_id_hex in _clients: - return _clients[session_id_hex] - else: - logger.error( - "Telemetry client not initialized for connection %s", session_id_hex +class TelemetryClientFactory: + """ + Static factory class for creating and managing telemetry clients. + It uses a thread pool to handle asynchronous operations. + """ + + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient + _executor: Optional[ThreadPoolExecutor] = None + _initialized: bool = False + _lock = threading.Lock() # Thread safety for factory operations + _original_excepthook = None + _excepthook_installed = False + + @classmethod + def _initialize(cls): + """Initialize the factory if not already initialized""" + + if not cls._initialized: + cls._clients = {} + cls._executor = ThreadPoolExecutor( + max_workers=10 + ) # Thread pool for async operations TODO: Decide on max workers + cls._install_exception_hook() + cls._initialized = True + logger.debug( + "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) - return NoopTelemetryClient() - except Exception as e: - logger.debug("Failed to get telemetry client: %s", e) - return NoopTelemetryClient() - - -def close_telemetry_client(session_id_hex): - """Remove the telemetry client for a specific connection""" - global _initialized, _executor - with _lock: - # if (telemetry_client := _clients.pop(session_id_hex, None)) is not None: - if session_id_hex in _clients: - telemetry_client = _clients.pop(session_id_hex) - logger.debug("Removing telemetry client for connection %s", session_id_hex) - telemetry_client.close() - - # Shutdown executor if no more clients + + @classmethod + def _install_exception_hook(cls): + """Install global exception handler for unhandled exceptions""" + if not cls._excepthook_installed: + cls._original_excepthook = sys.excepthook + sys.excepthook = cls._handle_unhandled_exception + cls._excepthook_installed = True + logger.debug("Global exception handler installed for telemetry") + + @classmethod + def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): + """Handle unhandled exceptions by sending telemetry and flushing thread pool""" + logger.debug("Handling unhandled exception: %s", exc_type.__name__) + + clients_to_close = list(cls._clients.values()) + for client in clients_to_close: + client.close() + + # Call the original exception handler to maintain normal behavior + if cls._original_excepthook: + cls._original_excepthook(exc_type, exc_value, exc_traceback) + + @staticmethod + def initialize_telemetry_client( + telemetry_enabled, + session_id_hex, + auth_provider, + host_url, + ): + """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - if not _clients and _executor: + + with TelemetryClientFactory._lock: + TelemetryClientFactory._initialize() + + if session_id_hex not in TelemetryClientFactory._clients: + logger.debug( + "Creating new TelemetryClient for connection %s", + session_id_hex, + ) + if telemetry_enabled: + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + ) + else: + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to initialize telemetry client: %s", e) + # Fallback to NoopTelemetryClient to ensure connection doesn't fail + TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + + @staticmethod + def get_telemetry_client(session_id_hex): + """Get the telemetry client for a specific connection""" + try: + if session_id_hex in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[session_id_hex] + else: + logger.error( + "Telemetry client not initialized for connection %s", + session_id_hex, + ) + return NoopTelemetryClient() + except Exception as e: + logger.debug("Failed to get telemetry client: %s", e) + return NoopTelemetryClient() + + @staticmethod + def close(session_id_hex): + """Close and remove the telemetry client for a specific connection""" + + with TelemetryClientFactory._lock: + if ( + telemetry_client := TelemetryClientFactory._clients.pop( + session_id_hex, None + ) + ) is not None: + logger.debug( + "Removing telemetry client for connection %s", session_id_hex + ) + telemetry_client.close() + + # Shutdown executor if no more clients + if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: logger.debug( "No more telemetry clients, shutting down thread pool executor" ) - _executor.shutdown(wait=True) - _executor = None - _initialized = False - except Exception as e: - logger.debug("Failed to shutdown thread pool executor: %s", e) + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d3a0c9866..f9206dc27 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -336,7 +336,6 @@ def test_negative_fetch_throws_exception(self): result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): - print("hellow") mock_close = Mock() with client.Cursor(Mock(), Mock()) as cursor: cursor.close = mock_close @@ -352,8 +351,7 @@ def test_context_manager_closes_cursor(self): cursor.close.assert_called() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_a_context_manager_closes_connection(self, mock_client_class): - print("hellow1") + def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() @@ -792,7 +790,6 @@ def test_cursor_context_manager_handles_exit_exception(self): def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" - print("banana") cursors_closed = [] def mock_close_with_exception(): diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 84833dafd..699480bbe 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -6,9 +6,7 @@ from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, NoopTelemetryClient, - initialize_telemetry_client, - get_telemetry_client, - close_telemetry_client, + TelemetryClientFactory, TelemetryHelper, BaseTelemetryClient ) @@ -63,19 +61,18 @@ def telemetry_client_setup(): def telemetry_system_reset(): """Fixture to reset telemetry system state before each test.""" # Reset the static state before each test - from databricks.sql.telemetry.telemetry_client import _clients, _executor, _initialized - _clients.clear() - if _executor: - _executor.shutdown(wait=True) - _executor = None - _initialized = False + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False yield # Cleanup after test if needed - _clients.clear() - if _executor: - _executor.shutdown(wait=True) - _executor = None - _initialized = False + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False class TestNoopTelemetryClient: @@ -335,6 +332,77 @@ class TestBaseClient(BaseTelemetryClient): TestBaseClient() # Can't instantiate abstract class +class TestTelemetryHelper: + """Tests for the TelemetryHelper class.""" + + def test_get_driver_system_configuration(self): + """Test getting driver system configuration.""" + config = TelemetryHelper.get_driver_system_configuration() + + assert isinstance(config.driver_name, str) + assert isinstance(config.driver_version, str) + assert isinstance(config.runtime_name, str) + assert isinstance(config.runtime_vendor, str) + assert isinstance(config.runtime_version, str) + assert isinstance(config.os_name, str) + assert isinstance(config.os_version, str) + assert isinstance(config.os_arch, str) + assert isinstance(config.locale_name, str) + assert isinstance(config.char_set_encoding, str) + + assert config.driver_name == "Databricks SQL Python Connector" + assert "Python" in config.runtime_name + assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] + assert config.os_name in ["Darwin", "Linux", "Windows"] + + # Verify caching behavior + config2 = TelemetryHelper.get_driver_system_configuration() + assert config is config2 # Should return same instance + + def test_get_auth_mechanism(self): + """Test getting auth mechanism for different auth providers.""" + # Test PAT auth + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT + + # Test OAuth auth + oauth_auth = MagicMock(spec=DatabricksOAuthProvider) + assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH + + # Test External auth + external_auth = MagicMock(spec=ExternalAuthProvider) + assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH + + # Test None auth provider + assert TelemetryHelper.get_auth_mechanism(None) is None + + # Test unknown auth provider + unknown_auth = MagicMock() + assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT + + def test_get_auth_flow(self): + """Test getting auth flow for different OAuth providers.""" + # Test OAuth with existing tokens + oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_tokens._access_token = "test-access-token" + oauth_with_tokens._refresh_token = "test-refresh-token" + assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH + + # Test OAuth with browser-based auth + oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) + oauth_with_browser._access_token = None + oauth_with_browser._refresh_token = None + oauth_with_browser.oauth_manager = MagicMock() + assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION + + # Test non-OAuth provider + pat_auth = AccessTokenAuthProvider("test-token") + assert TelemetryHelper.get_auth_flow(pat_auth) is None + + # Test None auth provider + assert TelemetryHelper.get_auth_flow(None) is None + + class TestTelemetrySystem: """Tests for the telemetry system functions.""" @@ -344,14 +412,14 @@ def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): auth_provider = MagicMock() host_url = "test-host" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex assert client._auth_provider == auth_provider @@ -360,19 +428,19 @@ def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): """Test initializing a telemetry client when telemetry is disabled.""" session_id_hex = "test-uuid" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): """Test getting a non-existent telemetry client.""" - client = get_telemetry_client("nonexistent-uuid") + client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client(self, telemetry_system_reset): @@ -381,126 +449,53 @@ def test_close_telemetry_client(self, telemetry_system_reset): auth_provider = MagicMock() host_url = "test-host" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, host_url=host_url, ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) client.close = MagicMock() - close_telemetry_client(session_id_hex) + TelemetryClientFactory.close(session_id_hex) client.close.assert_called_once() - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) def test_close_telemetry_client_noop(self, telemetry_system_reset): """Test closing a no-op telemetry client.""" session_id_hex = "test-uuid" - initialize_telemetry_client( + TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, session_id_hex=session_id_hex, auth_provider=MagicMock(), host_url="test-host", ) - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) client.close = MagicMock() - close_telemetry_client(session_id_hex) + TelemetryClientFactory.close(session_id_hex) client.close.assert_called_once() - client = get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client._handle_unhandled_exception") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory._handle_unhandled_exception") def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): """Test that global exception hook is installed and handles exceptions.""" - from databricks.sql.telemetry.telemetry_client import _install_exception_hook, _handle_unhandled_exception - - _install_exception_hook() + TelemetryClientFactory._install_exception_hook() test_exception = ValueError("Test exception") - _handle_unhandled_exception(type(test_exception), test_exception, None) + TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) - - -class TestTelemetryHelper: - """Tests for the TelemetryHelper class.""" - - def test_get_driver_system_configuration(self): - """Test getting driver system configuration.""" - config = TelemetryHelper.get_driver_system_configuration() - - assert isinstance(config.driver_name, str) - assert isinstance(config.driver_version, str) - assert isinstance(config.runtime_name, str) - assert isinstance(config.runtime_vendor, str) - assert isinstance(config.runtime_version, str) - assert isinstance(config.os_name, str) - assert isinstance(config.os_version, str) - assert isinstance(config.os_arch, str) - assert isinstance(config.locale_name, str) - assert isinstance(config.char_set_encoding, str) - - assert config.driver_name == "Databricks SQL Python Connector" - assert "Python" in config.runtime_name - assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] - assert config.os_name in ["Darwin", "Linux", "Windows"] - - # Verify caching behavior - config2 = TelemetryHelper.get_driver_system_configuration() - assert config is config2 # Should return same instance - - def test_get_auth_mechanism(self): - """Test getting auth mechanism for different auth providers.""" - # Test PAT auth - pat_auth = AccessTokenAuthProvider("test-token") - assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT - - # Test OAuth auth - oauth_auth = MagicMock(spec=DatabricksOAuthProvider) - assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH - - # Test External auth - external_auth = MagicMock(spec=ExternalAuthProvider) - assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH - - # Test None auth provider - assert TelemetryHelper.get_auth_mechanism(None) is None - - # Test unknown auth provider - unknown_auth = MagicMock() - assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT - - def test_get_auth_flow(self): - """Test getting auth flow for different OAuth providers.""" - # Test OAuth with existing tokens - oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) - oauth_with_tokens._access_token = "test-access-token" - oauth_with_tokens._refresh_token = "test-refresh-token" - assert TelemetryHelper.get_auth_flow(oauth_with_tokens) == AuthFlow.TOKEN_PASSTHROUGH - - # Test OAuth with browser-based auth - oauth_with_browser = MagicMock(spec=DatabricksOAuthProvider) - oauth_with_browser._access_token = None - oauth_with_browser._refresh_token = None - oauth_with_browser.oauth_manager = MagicMock() - assert TelemetryHelper.get_auth_flow(oauth_with_browser) == AuthFlow.BROWSER_BASED_AUTHENTICATION - - # Test non-OAuth provider - pat_auth = AccessTokenAuthProvider("test-token") - assert TelemetryHelper.get_auth_flow(pat_auth) is None - - # Test None auth provider - assert TelemetryHelper.get_auth_flow(None) is None \ No newline at end of file + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file From f101b198d7e23b51f6758fbd0869c658910e9a65 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 18 Jun 2025 10:44:09 +0530 Subject: [PATCH 48/86] remove more debugging Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- tests/unit/test_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index df6a0e169..462d22369 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s + run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest strategy: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f9206dc27..588b0d70e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -52,6 +52,7 @@ def new(cls): ) ThriftBackendMock.execute_command.return_value = MockTExecuteStatementResp + return ThriftBackendMock @classmethod From a0946593920b67c2429f0036136fcab169b8704f Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 11:18:26 +0530 Subject: [PATCH 49/86] latency logs funcitionality Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 17 ++- .../sql/telemetry/latency_logger.py | 131 ++++++++++++++++++ .../sql/telemetry/telemetry_client.py | 39 ++++++ src/databricks/sql/thrift_backend.py | 3 + 4 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 src/databricks/sql/telemetry/latency_logger.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 26705f3f8..c085f8f5e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -61,7 +61,7 @@ DriverConnectionParameters, HostDetails, ) - +from databricks.sql.telemetry.latency_logger import log_latency logger = logging.getLogger(__name__) @@ -758,6 +758,7 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency() def _handle_staging_put( self, presigned_url: str, local_file: str, headers: Optional[dict] = None ): @@ -797,6 +798,7 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) + @log_latency() def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None ): @@ -824,6 +826,7 @@ def _handle_staging_get( with open(local_file, "wb") as fp: fp.write(r.content) + @log_latency() def _handle_staging_remove( self, presigned_url: str, headers: Optional[dict] = None ): @@ -837,6 +840,7 @@ def _handle_staging_remove( session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency() def execute( self, operation: str, @@ -927,6 +931,7 @@ def execute( return self + @log_latency() def execute_async( self, operation: str, @@ -1052,6 +1057,7 @@ def executemany(self, operation, seq_of_parameters): self.execute(operation, parameters) return self + @log_latency() def catalogs(self) -> "Cursor": """ Get all available catalogs. @@ -1075,6 +1081,7 @@ def catalogs(self) -> "Cursor": ) return self + @log_latency() def schemas( self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None ) -> "Cursor": @@ -1103,6 +1110,7 @@ def schemas( ) return self + @log_latency() def tables( self, catalog_name: Optional[str] = None, @@ -1138,6 +1146,7 @@ def tables( ) return self + @log_latency() def columns( self, catalog_name: Optional[str] = None, @@ -1173,6 +1182,7 @@ def columns( ) return self + @log_latency() def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences. @@ -1206,6 +1216,7 @@ def fetchone(self) -> Optional[Row]: session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency() def fetchmany(self, size: int) -> List[Row]: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a @@ -1231,6 +1242,7 @@ def fetchmany(self, size: int) -> List[Row]: session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency() def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: @@ -1241,6 +1253,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency() def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: @@ -1406,6 +1419,7 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows + @log_latency() def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) @@ -1418,6 +1432,7 @@ def _convert_columnar_table(self, table): return result + @log_latency() def _convert_arrow_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py new file mode 100644 index 000000000..955f0a6e1 --- /dev/null +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -0,0 +1,131 @@ +import time +import functools +from typing import Optional +from uuid import UUID +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.telemetry.models.event import ( + SqlExecutionEvent, +) +from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType +from databricks.sql.utils import ColumnQueue + +# Helper to get statement_id/query_id from instance if available +def _get_statement_id(instance) -> Optional[str]: + """ + Get statement ID from an instance using various methods: + 1. For Cursor: Use query_id property which returns UUID from active_op_handle + 2. For ResultSet: Use command_id which contains operationId + + Note: ThriftBackend itself doesn't have a statement ID since one backend + can handle multiple concurrent operations/cursors. + """ + if hasattr(instance, "query_id"): + return instance.query_id + + if hasattr(instance, "command_id") and instance.command_id: + return str(UUID(bytes=instance.command_id.operationId.guid)) + + return None + + +def _get_session_id_hex(instance) -> Optional[str]: + if hasattr(instance, "connection") and instance.connection: + return instance.connection.get_session_id_hex() + if hasattr(instance, "get_session_id_hex"): + return instance.get_session_id_hex() + return None + + +def _get_statement_type(func_name: str) -> StatementType: # TODO: implement this + return StatementType.SQL + + +def _get_is_compressed(instance) -> bool: + """ + Get compression status from instance: + 1. Direct lz4_compression attribute (Connection) + 2. Through connection attribute (Cursor/ResultSet) + 3. Through thrift_backend attribute (Cursor) + """ + if hasattr(instance, "lz4_compression"): + return instance.lz4_compression + if hasattr(instance, "connection") and instance.connection: + return instance.connection.lz4_compression + if hasattr(instance, "thrift_backend") and instance.thrift_backend: + return instance.thrift_backend.lz4_compressed + return False + + +def _get_execution_result(instance) -> ExecutionResultFormat: + """ + Get execution result format from instance: + 1. For ResultSet: Check if using cloud fetch (external_links) or arrow/columnar format + 2. For Cursor: Check through active_result_set + 3. For ThriftBackend: Check result format from server + """ + if hasattr(instance, "_use_cloud_fetch") and instance._use_cloud_fetch: + return ExecutionResultFormat.EXTERNAL_LINKS + + if hasattr(instance, "active_result_set") and instance.active_result_set: + if isinstance(instance.active_result_set.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + return ExecutionResultFormat.INLINE_ARROW + + if hasattr(instance, "thrift_backend") and instance.thrift_backend: + if hasattr(instance.thrift_backend, "_use_arrow_native_complex_types"): + return ExecutionResultFormat.INLINE_ARROW + + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + +def _get_retry_count(instance) -> int: + """ + Get retry count from instance by checking retry_policy.history length. + The retry_policy is only accessible through thrift_backend. + """ + # TODO: implement this + + return 0 + + +def log_latency(): + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + start_time = time.perf_counter() + result = None + try: + result = func(self, *args, **kwargs) + return result + finally: + end_time = time.perf_counter() + duration_ms = int((end_time - start_time) * 1000) + + session_id_hex = _get_session_id_hex(self) + + if session_id_hex: + statement_id = _get_statement_id(self) + statement_type = _get_statement_type(func.__name__) + is_compressed = _get_is_compressed(self) + execution_result = _get_execution_result(self) + retry_count = _get_retry_count(self) + + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=is_compressed, + execution_result=execution_result, + retry_count=retry_count, + ) + + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=statement_id, + ) + + return wrapper + + return decorator diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f7fccf47a..6933df7ad 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -112,6 +112,12 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): def export_failure_log(self, error_name, error_message): raise NotImplementedError("Subclasses must implement export_failure_log") + @abstractmethod + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id=None + ): + raise NotImplementedError("Subclasses must implement export_latency_log") + @abstractmethod def close(self): raise NotImplementedError("Subclasses must implement close") @@ -136,6 +142,11 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): def export_failure_log(self, error_name, error_message): pass + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id=None + ): + pass + def close(self): pass @@ -299,6 +310,34 @@ def export_failure_log(self, error_name, error_message): except Exception as e: logger.debug("Failed to export failure log: %s", e) + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id=None + ): + logger.debug("Exporting latency log for connection %s", self._session_id_hex) + try: + telemetry_frontend_log = TelemetryFrontendLog( + frontend_log_event_id=str(uuid.uuid4()), + context=FrontendLogContext( + client_context=TelemetryClientContext( + timestamp_millis=int(time.time() * 1000), + user_agent=self._user_agent, + ) + ), + entry=FrontendLogEntry( + sql_driver_log=TelemetryEvent( + session_id=self._session_id_hex, + system_configuration=TelemetryHelper.get_driver_system_configuration(), + driver_connection_params=self._driver_connection_params, + sql_statement_id=sql_statement_id, + sql_operation=sql_execution_event, + operation_latency_ms=latency_ms, + ) + ), + ) + self._export_event(telemetry_frontend_log) + except Exception as e: + logger.debug("Failed to export latency log: %s", e) + def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 78683ac31..e1798e643 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -583,6 +583,9 @@ def open_session(self, session_configuration, catalog, schema): self._transport.close() raise + def get_session_id_hex(self) -> str: + return self._session_id_hex + def close_session(self, session_handle) -> None: req = ttypes.TCloseSessionReq(sessionHandle=session_handle) try: From fc918d6641412f448cbd042d72b8f4c915f40c43 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 14:49:32 +0530 Subject: [PATCH 50/86] fixed type of return value in get_session_id_hex() in thrift backend Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/thrift_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e1798e643..d7f9bdd06 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -41,6 +41,7 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions +from typing import Optional logger = logging.getLogger(__name__) @@ -583,7 +584,7 @@ def open_session(self, session_configuration, catalog, schema): self._transport.close() raise - def get_session_id_hex(self) -> str: + def get_session_id_hex(self) -> Optional[str]: return self._session_id_hex def close_session(self, session_handle) -> None: From d7c75d7cbbf71d3503234ae7b5b6bee6a9c167a9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:07:17 +0530 Subject: [PATCH 51/86] debug on TelemetryClientFactory lock Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- .../sql/telemetry/telemetry_client.py | 73 ++++++++++++++++++- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 462d22369..244118195 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit -v -s check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 6933df7ad..ef5df1ca5 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -31,6 +31,76 @@ logger = logging.getLogger(__name__) +class DebugLock: + """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" + def __init__(self, name: str = "DebugLock"): + self._lock = threading.Lock() + self._name = name + self._owner = None + self._waiters = [] + self._debug_logger = logging.getLogger(f"{__name__}.{name}") + # Ensure debug logging is visible + if not self._debug_logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" + ) + handler.setFormatter(formatter) + self._debug_logger.addHandler(handler) + self._debug_logger.setLevel(logging.DEBUG) + def acquire(self, blocking=True, timeout=-1): + current = threading.current_thread() + thread_info = f"{current.name}-{current.ident}" + if self._owner: + self._debug_logger.warning( + f":rotating_light: WAITING: {thread_info} waiting for lock held by {self._owner}" + ) + self._waiters.append(thread_info) + else: + self._debug_logger.debug( + f":large_green_circle: TRYING: {thread_info} attempting to acquire lock" + ) + # Try to acquire the lock + acquired = self._lock.acquire(blocking, timeout) + if acquired: + self._owner = thread_info + self._debug_logger.info(f":white_check_mark: ACQUIRED: {thread_info} got the lock") + if self._waiters: + self._debug_logger.info( + f":clipboard: WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" + ) + else: + self._debug_logger.error( + f":x: FAILED: {thread_info} failed to acquire lock (timeout)" + ) + if thread_info in self._waiters: + self._waiters.remove(thread_info) + return acquired + def release(self): + current = threading.current_thread() + thread_info = f"{current.name}-{current.ident}" + if self._owner != thread_info: + self._debug_logger.error( + f":rotating_light: ERROR: {thread_info} trying to release lock owned by {self._owner}" + ) + else: + self._debug_logger.info(f":unlock: RELEASED: {thread_info} released the lock") + self._owner = None + # Remove from waiters if present + if thread_info in self._waiters: + self._waiters.remove(thread_info) + if self._waiters: + self._debug_logger.info( + f":loudspeaker: NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" + ) + self._lock.release() + def __enter__(self): + self.acquire() + return self + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() + + class TelemetryHelper: """Helper class for getting telemetry related information.""" @@ -355,7 +425,8 @@ class TelemetryClientFactory: ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False - _lock = threading.Lock() # Thread safety for factory operations + # _lock = threading.Lock() # Thread safety for factory operations + _lock = DebugLock("TelemetryClientFactory") # Thread safety for factory operations with debugging _original_excepthook = None _excepthook_installed = False From b6b0f8948f20c5c501522243eca515943c20c61c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:09:06 +0530 Subject: [PATCH 52/86] formatting Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index ef5df1ca5..482986e04 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -33,6 +33,7 @@ class DebugLock: """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" + def __init__(self, name: str = "DebugLock"): self._lock = threading.Lock() self._name = name @@ -48,6 +49,7 @@ def __init__(self, name: str = "DebugLock"): handler.setFormatter(formatter) self._debug_logger.addHandler(handler) self._debug_logger.setLevel(logging.DEBUG) + def acquire(self, blocking=True, timeout=-1): current = threading.current_thread() thread_info = f"{current.name}-{current.ident}" @@ -64,7 +66,9 @@ def acquire(self, blocking=True, timeout=-1): acquired = self._lock.acquire(blocking, timeout) if acquired: self._owner = thread_info - self._debug_logger.info(f":white_check_mark: ACQUIRED: {thread_info} got the lock") + self._debug_logger.info( + f":white_check_mark: ACQUIRED: {thread_info} got the lock" + ) if self._waiters: self._debug_logger.info( f":clipboard: WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" @@ -76,6 +80,7 @@ def acquire(self, blocking=True, timeout=-1): if thread_info in self._waiters: self._waiters.remove(thread_info) return acquired + def release(self): current = threading.current_thread() thread_info = f"{current.name}-{current.ident}" @@ -84,7 +89,9 @@ def release(self): f":rotating_light: ERROR: {thread_info} trying to release lock owned by {self._owner}" ) else: - self._debug_logger.info(f":unlock: RELEASED: {thread_info} released the lock") + self._debug_logger.info( + f":unlock: RELEASED: {thread_info} released the lock" + ) self._owner = None # Remove from waiters if present if thread_info in self._waiters: @@ -94,9 +101,11 @@ def release(self): f":loudspeaker: NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" ) self._lock.release() + def __enter__(self): self.acquire() return self + def __exit__(self, exc_type, exc_val, exc_tb): self.release() @@ -426,7 +435,9 @@ class TelemetryClientFactory: _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False # _lock = threading.Lock() # Thread safety for factory operations - _lock = DebugLock("TelemetryClientFactory") # Thread safety for factory operations with debugging + _lock = DebugLock( + "TelemetryClientFactory" + ) # Thread safety for factory operations with debugging _original_excepthook = None _excepthook_installed = False From 50a12060c0311939fdc37c14d40072b0313664cc Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:17:14 +0530 Subject: [PATCH 53/86] type notation for _waiters Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 482986e04..dc598cff6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,7 +4,7 @@ import requests import logging from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional +from typing import Dict, Optional, List from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -37,8 +37,8 @@ class DebugLock: def __init__(self, name: str = "DebugLock"): self._lock = threading.Lock() self._name = name - self._owner = None - self._waiters = [] + self._owner: Optional[str] = None + self._waiters: List[str] = [] self._debug_logger = logging.getLogger(f"{__name__}.{name}") # Ensure debug logging is visible if not self._debug_logger.handlers: @@ -55,27 +55,25 @@ def acquire(self, blocking=True, timeout=-1): thread_info = f"{current.name}-{current.ident}" if self._owner: self._debug_logger.warning( - f":rotating_light: WAITING: {thread_info} waiting for lock held by {self._owner}" + f": WAITING: {thread_info} waiting for lock held by {self._owner}" ) self._waiters.append(thread_info) else: self._debug_logger.debug( - f":large_green_circle: TRYING: {thread_info} attempting to acquire lock" + f": TRYING: {thread_info} attempting to acquire lock" ) # Try to acquire the lock acquired = self._lock.acquire(blocking, timeout) if acquired: self._owner = thread_info - self._debug_logger.info( - f":white_check_mark: ACQUIRED: {thread_info} got the lock" - ) + self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") if self._waiters: self._debug_logger.info( - f":clipboard: WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" + f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" ) else: self._debug_logger.error( - f":x: FAILED: {thread_info} failed to acquire lock (timeout)" + f": FAILED: {thread_info} failed to acquire lock (timeout)" ) if thread_info in self._waiters: self._waiters.remove(thread_info) @@ -86,19 +84,17 @@ def release(self): thread_info = f"{current.name}-{current.ident}" if self._owner != thread_info: self._debug_logger.error( - f":rotating_light: ERROR: {thread_info} trying to release lock owned by {self._owner}" + f": ERROR: {thread_info} trying to release lock owned by {self._owner}" ) else: - self._debug_logger.info( - f":unlock: RELEASED: {thread_info} released the lock" - ) + self._debug_logger.info(f": RELEASED: {thread_info} released the lock") self._owner = None # Remove from waiters if present if thread_info in self._waiters: self._waiters.remove(thread_info) if self._waiters: self._debug_logger.info( - f":loudspeaker: NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" + f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" ) self._lock.release() From 39a053060949ef8c60228b073553efd014e850e7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:29:50 +0530 Subject: [PATCH 54/86] called connection.close() in test_arraysize_buffer_size_passthrough Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- tests/unit/test_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 244118195..a40357ec9 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s + run: poetry run python -m pytest tests/unit/test_client.py -v -s check-linting: runs-on: ubuntu-latest strategy: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..62bc9ee13 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -264,6 +264,7 @@ def test_arraysize_buffer_size_passthrough( self.assertEqual(kwargs["arraysize"], 999) self.assertEqual(kwargs["result_buffer_size_bytes"], 1234) + connection.close() def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() From 413427f6fa23f5bcbef8e68e688da2da5932ac93 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:33:10 +0530 Subject: [PATCH 55/86] run all unit tests Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index a40357ec9..244118195 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit/test_client.py -v -s + run: poetry run python -m pytest tests/unit -v -s check-linting: runs-on: ubuntu-latest strategy: From 6b1d1b8c4a255159061039949eb6fef1b62d3816 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:39:59 +0530 Subject: [PATCH 56/86] more debugging Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index dc598cff6..2da8e07ba 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -465,7 +465,7 @@ def _install_exception_hook(cls): def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - + print("Handling unhandled exception: %s", exc_type.__name__) clients_to_close = list(cls._clients.values()) for client in clients_to_close: client.close() @@ -530,8 +530,9 @@ def get_telemetry_client(session_id_hex): @staticmethod def close(session_id_hex): """Close and remove the telemetry client for a specific connection""" - + print("Closing telemetry client: %s", session_id_hex) with TelemetryClientFactory._lock: + print("Closing telemetry client, got lock: %s", session_id_hex) if ( telemetry_client := TelemetryClientFactory._clients.pop( session_id_hex, None @@ -547,6 +548,8 @@ def close(session_id_hex): logger.debug( "No more telemetry clients, shutting down thread pool executor" ) + print("Shutting down thread pool executor: %s", session_id_hex) TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False + print("Thread pool executor shut down: %s", session_id_hex) From 8f5e5ba85c19d26d28f4d03781b439f658131eb6 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:49:53 +0530 Subject: [PATCH 57/86] removed the connection.close() from that test, put debug statement before and after TelemetryClientFactory lock Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 158 +++++++++--------- tests/unit/test_client.py | 1 - 2 files changed, 80 insertions(+), 79 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2da8e07ba..8c3e6b693 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -31,79 +31,79 @@ logger = logging.getLogger(__name__) -class DebugLock: - """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" - - def __init__(self, name: str = "DebugLock"): - self._lock = threading.Lock() - self._name = name - self._owner: Optional[str] = None - self._waiters: List[str] = [] - self._debug_logger = logging.getLogger(f"{__name__}.{name}") - # Ensure debug logging is visible - if not self._debug_logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter( - ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" - ) - handler.setFormatter(formatter) - self._debug_logger.addHandler(handler) - self._debug_logger.setLevel(logging.DEBUG) - - def acquire(self, blocking=True, timeout=-1): - current = threading.current_thread() - thread_info = f"{current.name}-{current.ident}" - if self._owner: - self._debug_logger.warning( - f": WAITING: {thread_info} waiting for lock held by {self._owner}" - ) - self._waiters.append(thread_info) - else: - self._debug_logger.debug( - f": TRYING: {thread_info} attempting to acquire lock" - ) - # Try to acquire the lock - acquired = self._lock.acquire(blocking, timeout) - if acquired: - self._owner = thread_info - self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") - if self._waiters: - self._debug_logger.info( - f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" - ) - else: - self._debug_logger.error( - f": FAILED: {thread_info} failed to acquire lock (timeout)" - ) - if thread_info in self._waiters: - self._waiters.remove(thread_info) - return acquired - - def release(self): - current = threading.current_thread() - thread_info = f"{current.name}-{current.ident}" - if self._owner != thread_info: - self._debug_logger.error( - f": ERROR: {thread_info} trying to release lock owned by {self._owner}" - ) - else: - self._debug_logger.info(f": RELEASED: {thread_info} released the lock") - self._owner = None - # Remove from waiters if present - if thread_info in self._waiters: - self._waiters.remove(thread_info) - if self._waiters: - self._debug_logger.info( - f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" - ) - self._lock.release() - - def __enter__(self): - self.acquire() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.release() +# class DebugLock: +# """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" + +# def __init__(self, name: str = "DebugLock"): +# self._lock = threading.Lock() +# self._name = name +# self._owner: Optional[str] = None +# self._waiters: List[str] = [] +# self._debug_logger = logging.getLogger(f"{__name__}.{name}") +# # Ensure debug logging is visible +# if not self._debug_logger.handlers: +# handler = logging.StreamHandler() +# formatter = logging.Formatter( +# ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" +# ) +# handler.setFormatter(formatter) +# self._debug_logger.addHandler(handler) +# self._debug_logger.setLevel(logging.DEBUG) + +# def acquire(self, blocking=True, timeout=-1): +# current = threading.current_thread() +# thread_info = f"{current.name}-{current.ident}" +# if self._owner: +# self._debug_logger.warning( +# f": WAITING: {thread_info} waiting for lock held by {self._owner}" +# ) +# self._waiters.append(thread_info) +# else: +# self._debug_logger.debug( +# f": TRYING: {thread_info} attempting to acquire lock" +# ) +# # Try to acquire the lock +# acquired = self._lock.acquire(blocking, timeout) +# if acquired: +# self._owner = thread_info +# self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") +# if self._waiters: +# self._debug_logger.info( +# f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" +# ) +# else: +# self._debug_logger.error( +# f": FAILED: {thread_info} failed to acquire lock (timeout)" +# ) +# if thread_info in self._waiters: +# self._waiters.remove(thread_info) +# return acquired + +# def release(self): +# current = threading.current_thread() +# thread_info = f"{current.name}-{current.ident}" +# if self._owner != thread_info: +# self._debug_logger.error( +# f": ERROR: {thread_info} trying to release lock owned by {self._owner}" +# ) +# else: +# self._debug_logger.info(f": RELEASED: {thread_info} released the lock") +# self._owner = None +# # Remove from waiters if present +# if thread_info in self._waiters: +# self._waiters.remove(thread_info) +# if self._waiters: +# self._debug_logger.info( +# f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" +# ) +# self._lock.release() + +# def __enter__(self): +# self.acquire() +# return self + +# def __exit__(self, exc_type, exc_val, exc_tb): +# self.release() class TelemetryHelper: @@ -430,10 +430,10 @@ class TelemetryClientFactory: ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False - # _lock = threading.Lock() # Thread safety for factory operations - _lock = DebugLock( - "TelemetryClientFactory" - ) # Thread safety for factory operations with debugging + _lock = threading.Lock() # Thread safety for factory operations + # _lock = DebugLock( + # "TelemetryClientFactory" + # ) # Thread safety for factory operations with debugging _original_excepthook = None _excepthook_installed = False @@ -483,8 +483,9 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - + print("Initializing telemetry client: %s", session_id_hex) with TelemetryClientFactory._lock: + print("Initializing telemetry client, got lock: %s", session_id_hex) TelemetryClientFactory._initialize() if session_id_hex not in TelemetryClientFactory._clients: @@ -506,6 +507,7 @@ def initialize_telemetry_client( TelemetryClientFactory._clients[ session_id_hex ] = NoopTelemetryClient() + print("Telemetry client initialized: %s", session_id_hex) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 62bc9ee13..588b0d70e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -264,7 +264,6 @@ def test_arraysize_buffer_size_passthrough( self.assertEqual(kwargs["arraysize"], 999) self.assertEqual(kwargs["result_buffer_size_bytes"], 1234) - connection.close() def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() From 2dc00ba69f8b96c3e9fb0e8febe75e6061726d53 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 15:56:19 +0530 Subject: [PATCH 58/86] more debug Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 38 +++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 8c3e6b693..e117cb2bb 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -483,9 +483,17 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - print("Initializing telemetry client: %s", session_id_hex) + print( + "\nWAITING: Initializing telemetry client: %s", + session_id_hex, + flush=True, + ) with TelemetryClientFactory._lock: - print("Initializing telemetry client, got lock: %s", session_id_hex) + print( + "\nACQUIRED: Initializing telemetry client, got lock: %s", + session_id_hex, + flush=True, + ) TelemetryClientFactory._initialize() if session_id_hex not in TelemetryClientFactory._clients: @@ -507,7 +515,11 @@ def initialize_telemetry_client( TelemetryClientFactory._clients[ session_id_hex ] = NoopTelemetryClient() - print("Telemetry client initialized: %s", session_id_hex) + print( + "\nRELEASED: Telemetry client initialized: %s", + session_id_hex, + flush=True, + ) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail @@ -532,9 +544,13 @@ def get_telemetry_client(session_id_hex): @staticmethod def close(session_id_hex): """Close and remove the telemetry client for a specific connection""" - print("Closing telemetry client: %s", session_id_hex) + print("\nWAITING: Closing telemetry client: %s", session_id_hex, flush=True) with TelemetryClientFactory._lock: - print("Closing telemetry client, got lock: %s", session_id_hex) + print( + "\nACQUIRED: Closing telemetry client, got lock: %s", + session_id_hex, + flush=True, + ) if ( telemetry_client := TelemetryClientFactory._clients.pop( session_id_hex, None @@ -550,8 +566,16 @@ def close(session_id_hex): logger.debug( "No more telemetry clients, shutting down thread pool executor" ) - print("Shutting down thread pool executor: %s", session_id_hex) + print( + "\nSHUTDOWN: Shutting down thread pool executor: %s", + session_id_hex, + flush=True, + ) TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - print("Thread pool executor shut down: %s", session_id_hex) + print( + "\nRELEASED: Thread pool executor shut down: %s", + session_id_hex, + flush=True, + ) From 1ff03d4ed6d521a17b539d15e8ce75664c44796d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 16:02:27 +0530 Subject: [PATCH 59/86] more more more Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e117cb2bb..e11053116 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -495,6 +495,11 @@ def initialize_telemetry_client( flush=True, ) TelemetryClientFactory._initialize() + print( + "\n TelemetryClientFactory initialized: %s", + session_id_hex, + flush=True, + ) if session_id_hex not in TelemetryClientFactory._clients: logger.debug( @@ -502,6 +507,7 @@ def initialize_telemetry_client( session_id_hex, ) if telemetry_enabled: + print("\nTelemetry enabled: %s", session_id_hex, flush=True) TelemetryClientFactory._clients[ session_id_hex ] = TelemetryClient( @@ -511,10 +517,21 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, ) + print( + "\n Telemetry client initialized: %s", + session_id_hex, + flush=True, + ) else: + print("\nTelemetry disabled: %s", session_id_hex, flush=True) TelemetryClientFactory._clients[ session_id_hex ] = NoopTelemetryClient() + print( + "\n Noop Telemetry client initialized: %s", + session_id_hex, + flush=True, + ) print( "\nRELEASED: Telemetry client initialized: %s", session_id_hex, From 6ff07c821d55c27fb6e94c870380f47631a0b751 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 16:07:41 +0530 Subject: [PATCH 60/86] why Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e11053116..3ea4866af 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -496,18 +496,23 @@ def initialize_telemetry_client( ) TelemetryClientFactory._initialize() print( - "\n TelemetryClientFactory initialized: %s", + "\n TelemetryClientFactory initialized: %s", session_id_hex, flush=True, ) if session_id_hex not in TelemetryClientFactory._clients: + print( + "\n Session ID not in clients: %s", + session_id_hex, + flush=True, + ) logger.debug( "Creating new TelemetryClient for connection %s", session_id_hex, ) if telemetry_enabled: - print("\nTelemetry enabled: %s", session_id_hex, flush=True) + print("\n Telemetry enabled: %s", session_id_hex, flush=True) TelemetryClientFactory._clients[ session_id_hex ] = TelemetryClient( @@ -518,17 +523,19 @@ def initialize_telemetry_client( executor=TelemetryClientFactory._executor, ) print( - "\n Telemetry client initialized: %s", + "\n Telemetry client initialized: %s", session_id_hex, flush=True, ) else: - print("\nTelemetry disabled: %s", session_id_hex, flush=True) + print( + "\n Telemetry disabled: %s", session_id_hex, flush=True + ) TelemetryClientFactory._clients[ session_id_hex ] = NoopTelemetryClient() print( - "\n Noop Telemetry client initialized: %s", + "\n Noop Telemetry client initialized: %s", session_id_hex, flush=True, ) From 395049a7675e72d961b8d803da575074289e2b6e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 16:11:56 +0530 Subject: [PATCH 61/86] whywhy Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 3ea4866af..10871f310 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -539,12 +539,24 @@ def initialize_telemetry_client( session_id_hex, flush=True, ) + else: + print( + "\n Session ID already in clients: %s", + session_id_hex, + flush=True, + ) print( "\nRELEASED: Telemetry client initialized: %s", session_id_hex, flush=True, ) except Exception as e: + print( + "\nERROR: Failed to initialize telemetry client: %s due to %s", + session_id_hex, + e, + flush=True, + ) logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() From 44668211e2274653ba61275644dc5600fac69f00 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 19 Jun 2025 17:09:54 +0530 Subject: [PATCH 62/86] thread name Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 154 +++++++++--------- 1 file changed, 77 insertions(+), 77 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 10871f310..3adee8f76 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -31,79 +31,79 @@ logger = logging.getLogger(__name__) -# class DebugLock: -# """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" - -# def __init__(self, name: str = "DebugLock"): -# self._lock = threading.Lock() -# self._name = name -# self._owner: Optional[str] = None -# self._waiters: List[str] = [] -# self._debug_logger = logging.getLogger(f"{__name__}.{name}") -# # Ensure debug logging is visible -# if not self._debug_logger.handlers: -# handler = logging.StreamHandler() -# formatter = logging.Formatter( -# ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" -# ) -# handler.setFormatter(formatter) -# self._debug_logger.addHandler(handler) -# self._debug_logger.setLevel(logging.DEBUG) - -# def acquire(self, blocking=True, timeout=-1): -# current = threading.current_thread() -# thread_info = f"{current.name}-{current.ident}" -# if self._owner: -# self._debug_logger.warning( -# f": WAITING: {thread_info} waiting for lock held by {self._owner}" -# ) -# self._waiters.append(thread_info) -# else: -# self._debug_logger.debug( -# f": TRYING: {thread_info} attempting to acquire lock" -# ) -# # Try to acquire the lock -# acquired = self._lock.acquire(blocking, timeout) -# if acquired: -# self._owner = thread_info -# self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") -# if self._waiters: -# self._debug_logger.info( -# f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" -# ) -# else: -# self._debug_logger.error( -# f": FAILED: {thread_info} failed to acquire lock (timeout)" -# ) -# if thread_info in self._waiters: -# self._waiters.remove(thread_info) -# return acquired - -# def release(self): -# current = threading.current_thread() -# thread_info = f"{current.name}-{current.ident}" -# if self._owner != thread_info: -# self._debug_logger.error( -# f": ERROR: {thread_info} trying to release lock owned by {self._owner}" -# ) -# else: -# self._debug_logger.info(f": RELEASED: {thread_info} released the lock") -# self._owner = None -# # Remove from waiters if present -# if thread_info in self._waiters: -# self._waiters.remove(thread_info) -# if self._waiters: -# self._debug_logger.info( -# f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" -# ) -# self._lock.release() - -# def __enter__(self): -# self.acquire() -# return self - -# def __exit__(self, exc_type, exc_val, exc_tb): -# self.release() +class DebugLock: + """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" + + def __init__(self, name: str = "DebugLock"): + self._lock = threading.Lock() + self._name = name + self._owner: Optional[str] = None + self._waiters: List[str] = [] + self._debug_logger = logging.getLogger(f"{__name__}.{name}") + # Ensure debug logging is visible + if not self._debug_logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" + ) + handler.setFormatter(formatter) + self._debug_logger.addHandler(handler) + self._debug_logger.setLevel(logging.DEBUG) + + def acquire(self, blocking=True, timeout=-1): + current = threading.current_thread() + thread_info = f"{current.name}-{current.ident}" + if self._owner: + self._debug_logger.warning( + f": WAITING: {thread_info} waiting for lock held by {self._owner}" + ) + self._waiters.append(thread_info) + else: + self._debug_logger.debug( + f": TRYING: {thread_info} attempting to acquire lock" + ) + # Try to acquire the lock + acquired = self._lock.acquire(blocking, timeout) + if acquired: + self._owner = thread_info + self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") + if self._waiters: + self._debug_logger.info( + f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" + ) + else: + self._debug_logger.error( + f": FAILED: {thread_info} failed to acquire lock (timeout)" + ) + if thread_info in self._waiters: + self._waiters.remove(thread_info) + return acquired + + def release(self): + current = threading.current_thread() + thread_info = f"{current.name}-{current.ident}" + if self._owner != thread_info: + self._debug_logger.error( + f": ERROR: {thread_info} trying to release lock owned by {self._owner}" + ) + else: + self._debug_logger.info(f": RELEASED: {thread_info} released the lock") + self._owner = None + # Remove from waiters if present + if thread_info in self._waiters: + self._waiters.remove(thread_info) + if self._waiters: + self._debug_logger.info( + f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" + ) + self._lock.release() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() class TelemetryHelper: @@ -430,10 +430,10 @@ class TelemetryClientFactory: ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False - _lock = threading.Lock() # Thread safety for factory operations - # _lock = DebugLock( - # "TelemetryClientFactory" - # ) # Thread safety for factory operations with debugging + # _lock = threading.Lock() # Thread safety for factory operations + _lock = DebugLock( + "TelemetryClientFactory" + ) # Thread safety for factory operations with debugging _original_excepthook = None _excepthook_installed = False From 34b63e4fa09edc26babb35885e8407ce9f61641f Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 20 Jun 2025 10:13:48 +0530 Subject: [PATCH 63/86] added teardown to all tests except finalizer test (gc collect) Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 2 +- .../sql/telemetry/telemetry_client.py | 154 +++++++++--------- tests/unit/test_client.py | 38 +++++ 3 files changed, 116 insertions(+), 78 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c085f8f5e..cc87ca522 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -446,7 +446,7 @@ def _close(self, close_cursors=True) -> None: if close_cursors: for cursor in self._cursors: cursor.close() - + print(f"Closing session {self.get_session_id_hex()}") logger.info(f"Closing session {self.get_session_id_hex()}") if not self.open: logger.debug("Session appears to have been closed already") diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 3adee8f76..10871f310 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -31,79 +31,79 @@ logger = logging.getLogger(__name__) -class DebugLock: - """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" - - def __init__(self, name: str = "DebugLock"): - self._lock = threading.Lock() - self._name = name - self._owner: Optional[str] = None - self._waiters: List[str] = [] - self._debug_logger = logging.getLogger(f"{__name__}.{name}") - # Ensure debug logging is visible - if not self._debug_logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter( - ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" - ) - handler.setFormatter(formatter) - self._debug_logger.addHandler(handler) - self._debug_logger.setLevel(logging.DEBUG) - - def acquire(self, blocking=True, timeout=-1): - current = threading.current_thread() - thread_info = f"{current.name}-{current.ident}" - if self._owner: - self._debug_logger.warning( - f": WAITING: {thread_info} waiting for lock held by {self._owner}" - ) - self._waiters.append(thread_info) - else: - self._debug_logger.debug( - f": TRYING: {thread_info} attempting to acquire lock" - ) - # Try to acquire the lock - acquired = self._lock.acquire(blocking, timeout) - if acquired: - self._owner = thread_info - self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") - if self._waiters: - self._debug_logger.info( - f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" - ) - else: - self._debug_logger.error( - f": FAILED: {thread_info} failed to acquire lock (timeout)" - ) - if thread_info in self._waiters: - self._waiters.remove(thread_info) - return acquired - - def release(self): - current = threading.current_thread() - thread_info = f"{current.name}-{current.ident}" - if self._owner != thread_info: - self._debug_logger.error( - f": ERROR: {thread_info} trying to release lock owned by {self._owner}" - ) - else: - self._debug_logger.info(f": RELEASED: {thread_info} released the lock") - self._owner = None - # Remove from waiters if present - if thread_info in self._waiters: - self._waiters.remove(thread_info) - if self._waiters: - self._debug_logger.info( - f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" - ) - self._lock.release() - - def __enter__(self): - self.acquire() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.release() +# class DebugLock: +# """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" + +# def __init__(self, name: str = "DebugLock"): +# self._lock = threading.Lock() +# self._name = name +# self._owner: Optional[str] = None +# self._waiters: List[str] = [] +# self._debug_logger = logging.getLogger(f"{__name__}.{name}") +# # Ensure debug logging is visible +# if not self._debug_logger.handlers: +# handler = logging.StreamHandler() +# formatter = logging.Formatter( +# ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" +# ) +# handler.setFormatter(formatter) +# self._debug_logger.addHandler(handler) +# self._debug_logger.setLevel(logging.DEBUG) + +# def acquire(self, blocking=True, timeout=-1): +# current = threading.current_thread() +# thread_info = f"{current.name}-{current.ident}" +# if self._owner: +# self._debug_logger.warning( +# f": WAITING: {thread_info} waiting for lock held by {self._owner}" +# ) +# self._waiters.append(thread_info) +# else: +# self._debug_logger.debug( +# f": TRYING: {thread_info} attempting to acquire lock" +# ) +# # Try to acquire the lock +# acquired = self._lock.acquire(blocking, timeout) +# if acquired: +# self._owner = thread_info +# self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") +# if self._waiters: +# self._debug_logger.info( +# f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" +# ) +# else: +# self._debug_logger.error( +# f": FAILED: {thread_info} failed to acquire lock (timeout)" +# ) +# if thread_info in self._waiters: +# self._waiters.remove(thread_info) +# return acquired + +# def release(self): +# current = threading.current_thread() +# thread_info = f"{current.name}-{current.ident}" +# if self._owner != thread_info: +# self._debug_logger.error( +# f": ERROR: {thread_info} trying to release lock owned by {self._owner}" +# ) +# else: +# self._debug_logger.info(f": RELEASED: {thread_info} released the lock") +# self._owner = None +# # Remove from waiters if present +# if thread_info in self._waiters: +# self._waiters.remove(thread_info) +# if self._waiters: +# self._debug_logger.info( +# f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" +# ) +# self._lock.release() + +# def __enter__(self): +# self.acquire() +# return self + +# def __exit__(self, exc_type, exc_val, exc_tb): +# self.release() class TelemetryHelper: @@ -430,10 +430,10 @@ class TelemetryClientFactory: ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False - # _lock = threading.Lock() # Thread safety for factory operations - _lock = DebugLock( - "TelemetryClientFactory" - ) # Thread safety for factory operations with debugging + _lock = threading.Lock() # Thread safety for factory operations + # _lock = DebugLock( + # "TelemetryClientFactory" + # ) # Thread safety for factory operations with debugging _original_excepthook = None _excepthook_installed = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..c77d69f27 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -83,6 +83,44 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } + def setUp(self): + """Set up connection tracking before each test""" + self.tracked_connections = [] + + # Store the original connect function + self.original_connect = databricks.sql.connect + + def patched_connect(*args, **kwargs): + """Wrapper that tracks all created connections""" + conn = self.original_connect(*args, **kwargs) + + # Skip tracking for finalizer tests to allow garbage collection + if not (hasattr(self, '_testMethodName') and 'finalizer' in self._testMethodName): + self.tracked_connections.append(conn) + + return conn + + # Apply the patch to track connections + self.connect_patcher = patch('databricks.sql.connect', patched_connect) + self.connect_patcher.start() + + def tearDown(self): + """Clean up connections after each test""" + # Close all tracked connections + for conn in self.tracked_connections: + try: + if hasattr(conn, 'open') and conn.open: + conn.close() + except Exception as e: + # Log the error but don't fail the test + print(f"Warning: Error closing connection in tearDown: {e}") + + # Stop the connect patcher + self.connect_patcher.stop() + + # Clear the tracked connections list + self.tracked_connections.clear() + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value From 49082fb82d74e461a62bd7387c4ba8de90614880 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 20 Jun 2025 14:38:24 +0530 Subject: [PATCH 64/86] added the get_attribute functions to the classes Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- src/databricks/sql/client.py | 60 +++++++++ .../sql/telemetry/latency_logger.py | 124 +++--------------- src/databricks/sql/thrift_backend.py | 5 +- 4 files changed, 80 insertions(+), 111 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 244118195..df6a0e169 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s + run: poetry run python -m pytest tests/unit -v -s check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index cc87ca522..955ca0861 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -31,6 +31,8 @@ transform_paramstyle, ColumnTable, ColumnQueue, + ArrowQueue, + CloudFetchQueue, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -62,6 +64,7 @@ HostDetails, ) from databricks.sql.telemetry.latency_logger import log_latency +from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType logger = logging.getLogger(__name__) @@ -1355,6 +1358,35 @@ def setoutputsize(self, size, column=None): """Does nothing by default""" pass + def get_statement_id(self) -> Optional[str]: + return self.query_id + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.connection.lz4_compression + + def get_execution_result(self) -> ExecutionResultFormat: + if self.active_result_set is None: + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + if isinstance(self.active_result_set.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.active_result_set.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.active_result_set.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_retry_count(self) -> int: + # return len(self.thrift_backend.retry_policy.history) + return 0 + + def get_statement_type(self, func_name: str) -> StatementType: + # TODO: Implement this + return StatementType.SQL + class ResultSet: def __init__( @@ -1654,3 +1686,31 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + def get_statement_id(self) -> Optional[str]: + if self.command_id: + return str(UUID(bytes=self.command_id.operationId.guid)) + return None + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.lz4_compressed + + def get_execution_result(self) -> ExecutionResultFormat: + if isinstance(self.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_statement_type(self, func_name: str) -> StatementType: + # TODO: Implement this + return StatementType.SQL + + def get_retry_count(self) -> int: + # return len(self.thrift_backend.retry_policy.history) + return 0 diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 955f0a6e1..75460717d 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,91 +1,9 @@ import time import functools -from typing import Optional -from uuid import UUID from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( SqlExecutionEvent, ) -from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType -from databricks.sql.utils import ColumnQueue - -# Helper to get statement_id/query_id from instance if available -def _get_statement_id(instance) -> Optional[str]: - """ - Get statement ID from an instance using various methods: - 1. For Cursor: Use query_id property which returns UUID from active_op_handle - 2. For ResultSet: Use command_id which contains operationId - - Note: ThriftBackend itself doesn't have a statement ID since one backend - can handle multiple concurrent operations/cursors. - """ - if hasattr(instance, "query_id"): - return instance.query_id - - if hasattr(instance, "command_id") and instance.command_id: - return str(UUID(bytes=instance.command_id.operationId.guid)) - - return None - - -def _get_session_id_hex(instance) -> Optional[str]: - if hasattr(instance, "connection") and instance.connection: - return instance.connection.get_session_id_hex() - if hasattr(instance, "get_session_id_hex"): - return instance.get_session_id_hex() - return None - - -def _get_statement_type(func_name: str) -> StatementType: # TODO: implement this - return StatementType.SQL - - -def _get_is_compressed(instance) -> bool: - """ - Get compression status from instance: - 1. Direct lz4_compression attribute (Connection) - 2. Through connection attribute (Cursor/ResultSet) - 3. Through thrift_backend attribute (Cursor) - """ - if hasattr(instance, "lz4_compression"): - return instance.lz4_compression - if hasattr(instance, "connection") and instance.connection: - return instance.connection.lz4_compression - if hasattr(instance, "thrift_backend") and instance.thrift_backend: - return instance.thrift_backend.lz4_compressed - return False - - -def _get_execution_result(instance) -> ExecutionResultFormat: - """ - Get execution result format from instance: - 1. For ResultSet: Check if using cloud fetch (external_links) or arrow/columnar format - 2. For Cursor: Check through active_result_set - 3. For ThriftBackend: Check result format from server - """ - if hasattr(instance, "_use_cloud_fetch") and instance._use_cloud_fetch: - return ExecutionResultFormat.EXTERNAL_LINKS - - if hasattr(instance, "active_result_set") and instance.active_result_set: - if isinstance(instance.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - return ExecutionResultFormat.INLINE_ARROW - - if hasattr(instance, "thrift_backend") and instance.thrift_backend: - if hasattr(instance.thrift_backend, "_use_arrow_native_complex_types"): - return ExecutionResultFormat.INLINE_ARROW - - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - -def _get_retry_count(instance) -> int: - """ - Get retry count from instance by checking retry_policy.history length. - The retry_policy is only accessible through thrift_backend. - """ - # TODO: implement this - - return 0 def log_latency(): @@ -101,30 +19,24 @@ def wrapper(self, *args, **kwargs): end_time = time.perf_counter() duration_ms = int((end_time - start_time) * 1000) - session_id_hex = _get_session_id_hex(self) - - if session_id_hex: - statement_id = _get_statement_id(self) - statement_type = _get_statement_type(func.__name__) - is_compressed = _get_is_compressed(self) - execution_result = _get_execution_result(self) - retry_count = _get_retry_count(self) - - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=is_compressed, - execution_result=execution_result, - retry_count=retry_count, - ) - - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=statement_id, - ) + session_id_hex = self.get_session_id_hex() + statement_id = self.get_statement_id() + + sql_exec_event = SqlExecutionEvent( + statement_type=self.get_statement_type(func.__name__), + is_compressed=self.get_is_compressed(), + execution_result=self.get_execution_result(), + retry_count=self.get_retry_count(), + ) + + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=statement_id, + ) return wrapper diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index d7f9bdd06..0cd5fb124 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -41,7 +41,7 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions -from typing import Optional + logger = logging.getLogger(__name__) @@ -584,9 +584,6 @@ def open_session(self, session_configuration, catalog, schema): self._transport.close() raise - def get_session_id_hex(self) -> Optional[str]: - return self._session_id_hex - def close_session(self, session_handle) -> None: req = ttypes.TCloseSessionReq(sessionHandle=session_handle) try: From ed1db9d79e2d3ec628e3c1278ed3349d12c454fe Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 20 Jun 2025 14:46:32 +0530 Subject: [PATCH 65/86] removed tearDown, added connection.close() to first test Signed-off-by: Sai Shree Pradhan --- tests/unit/test_client.py | 59 ++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c77d69f27..4610928d2 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -83,43 +83,43 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - def setUp(self): - """Set up connection tracking before each test""" - self.tracked_connections = [] + # def setUp(self): + # """Set up connection tracking before each test""" + # self.tracked_connections = [] - # Store the original connect function - self.original_connect = databricks.sql.connect + # # Store the original connect function + # self.original_connect = databricks.sql.connect - def patched_connect(*args, **kwargs): - """Wrapper that tracks all created connections""" - conn = self.original_connect(*args, **kwargs) + # def patched_connect(*args, **kwargs): + # """Wrapper that tracks all created connections""" + # conn = self.original_connect(*args, **kwargs) - # Skip tracking for finalizer tests to allow garbage collection - if not (hasattr(self, '_testMethodName') and 'finalizer' in self._testMethodName): - self.tracked_connections.append(conn) + # # Skip tracking for finalizer tests to allow garbage collection + # if not (hasattr(self, '_testMethodName') and 'finalizer' in self._testMethodName): + # self.tracked_connections.append(conn) - return conn + # return conn - # Apply the patch to track connections - self.connect_patcher = patch('databricks.sql.connect', patched_connect) - self.connect_patcher.start() + # # Apply the patch to track connections + # self.connect_patcher = patch('databricks.sql.connect', patched_connect) + # self.connect_patcher.start() - def tearDown(self): - """Clean up connections after each test""" - # Close all tracked connections - for conn in self.tracked_connections: - try: - if hasattr(conn, 'open') and conn.open: - conn.close() - except Exception as e: - # Log the error but don't fail the test - print(f"Warning: Error closing connection in tearDown: {e}") + # def tearDown(self): + # """Clean up connections after each test""" + # # Close all tracked connections + # for conn in self.tracked_connections: + # try: + # if hasattr(conn, 'open') and conn.open: + # conn.close() + # except Exception as e: + # # Log the error but don't fail the test + # print(f"Warning: Error closing connection in tearDown: {e}") - # Stop the connect patcher - self.connect_patcher.stop() + # # Stop the connect patcher + # self.connect_patcher.stop() - # Clear the tracked connections list - self.tracked_connections.clear() + # # Clear the tracked connections list + # self.tracked_connections.clear() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): @@ -790,6 +790,7 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) + connection.close() def test_cursor_close_handles_exception(self): """Test that Cursor.close() handles exceptions from close_command properly.""" From 9fa5a89b2551d9cdc12a9f7aff3da42967c345bc Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sat, 21 Jun 2025 12:19:53 +0530 Subject: [PATCH 66/86] finally Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- src/databricks/sql/client.py | 2 +- .../sql/telemetry/telemetry_client.py | 148 +----------------- 3 files changed, 3 insertions(+), 149 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index df6a0e169..ba446f878 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit -v -s + run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 955ca0861..f232ba427 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -449,7 +449,7 @@ def _close(self, close_cursors=True) -> None: if close_cursors: for cursor in self._cursors: cursor.close() - print(f"Closing session {self.get_session_id_hex()}") + logger.info(f"Closing session {self.get_session_id_hex()}") if not self.open: logger.debug("Session appears to have been closed already") diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 10871f310..e3e71ba7f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -31,81 +31,6 @@ logger = logging.getLogger(__name__) -# class DebugLock: -# """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release""" - -# def __init__(self, name: str = "DebugLock"): -# self._lock = threading.Lock() -# self._name = name -# self._owner: Optional[str] = None -# self._waiters: List[str] = [] -# self._debug_logger = logging.getLogger(f"{__name__}.{name}") -# # Ensure debug logging is visible -# if not self._debug_logger.handlers: -# handler = logging.StreamHandler() -# formatter = logging.Formatter( -# ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s" -# ) -# handler.setFormatter(formatter) -# self._debug_logger.addHandler(handler) -# self._debug_logger.setLevel(logging.DEBUG) - -# def acquire(self, blocking=True, timeout=-1): -# current = threading.current_thread() -# thread_info = f"{current.name}-{current.ident}" -# if self._owner: -# self._debug_logger.warning( -# f": WAITING: {thread_info} waiting for lock held by {self._owner}" -# ) -# self._waiters.append(thread_info) -# else: -# self._debug_logger.debug( -# f": TRYING: {thread_info} attempting to acquire lock" -# ) -# # Try to acquire the lock -# acquired = self._lock.acquire(blocking, timeout) -# if acquired: -# self._owner = thread_info -# self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock") -# if self._waiters: -# self._debug_logger.info( -# f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}" -# ) -# else: -# self._debug_logger.error( -# f": FAILED: {thread_info} failed to acquire lock (timeout)" -# ) -# if thread_info in self._waiters: -# self._waiters.remove(thread_info) -# return acquired - -# def release(self): -# current = threading.current_thread() -# thread_info = f"{current.name}-{current.ident}" -# if self._owner != thread_info: -# self._debug_logger.error( -# f": ERROR: {thread_info} trying to release lock owned by {self._owner}" -# ) -# else: -# self._debug_logger.info(f": RELEASED: {thread_info} released the lock") -# self._owner = None -# # Remove from waiters if present -# if thread_info in self._waiters: -# self._waiters.remove(thread_info) -# if self._waiters: -# self._debug_logger.info( -# f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}" -# ) -# self._lock.release() - -# def __enter__(self): -# self.acquire() -# return self - -# def __exit__(self, exc_type, exc_val, exc_tb): -# self.release() - - class TelemetryHelper: """Helper class for getting telemetry related information.""" @@ -430,10 +355,7 @@ class TelemetryClientFactory: ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False - _lock = threading.Lock() # Thread safety for factory operations - # _lock = DebugLock( - # "TelemetryClientFactory" - # ) # Thread safety for factory operations with debugging + _lock = threading.RLock() # Thread safety for factory operations _original_excepthook = None _excepthook_installed = False @@ -465,7 +387,6 @@ def _install_exception_hook(cls): def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) - print("Handling unhandled exception: %s", exc_type.__name__) clients_to_close = list(cls._clients.values()) for client in clients_to_close: client.close() @@ -483,36 +404,15 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - print( - "\nWAITING: Initializing telemetry client: %s", - session_id_hex, - flush=True, - ) with TelemetryClientFactory._lock: - print( - "\nACQUIRED: Initializing telemetry client, got lock: %s", - session_id_hex, - flush=True, - ) TelemetryClientFactory._initialize() - print( - "\n TelemetryClientFactory initialized: %s", - session_id_hex, - flush=True, - ) if session_id_hex not in TelemetryClientFactory._clients: - print( - "\n Session ID not in clients: %s", - session_id_hex, - flush=True, - ) logger.debug( "Creating new TelemetryClient for connection %s", session_id_hex, ) if telemetry_enabled: - print("\n Telemetry enabled: %s", session_id_hex, flush=True) TelemetryClientFactory._clients[ session_id_hex ] = TelemetryClient( @@ -522,41 +422,11 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, ) - print( - "\n Telemetry client initialized: %s", - session_id_hex, - flush=True, - ) else: - print( - "\n Telemetry disabled: %s", session_id_hex, flush=True - ) TelemetryClientFactory._clients[ session_id_hex ] = NoopTelemetryClient() - print( - "\n Noop Telemetry client initialized: %s", - session_id_hex, - flush=True, - ) - else: - print( - "\n Session ID already in clients: %s", - session_id_hex, - flush=True, - ) - print( - "\nRELEASED: Telemetry client initialized: %s", - session_id_hex, - flush=True, - ) except Exception as e: - print( - "\nERROR: Failed to initialize telemetry client: %s due to %s", - session_id_hex, - e, - flush=True, - ) logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() @@ -580,13 +450,7 @@ def get_telemetry_client(session_id_hex): @staticmethod def close(session_id_hex): """Close and remove the telemetry client for a specific connection""" - print("\nWAITING: Closing telemetry client: %s", session_id_hex, flush=True) with TelemetryClientFactory._lock: - print( - "\nACQUIRED: Closing telemetry client, got lock: %s", - session_id_hex, - flush=True, - ) if ( telemetry_client := TelemetryClientFactory._clients.pop( session_id_hex, None @@ -602,16 +466,6 @@ def close(session_id_hex): logger.debug( "No more telemetry clients, shutting down thread pool executor" ) - print( - "\nSHUTDOWN: Shutting down thread pool executor: %s", - session_id_hex, - flush=True, - ) TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - print( - "\nRELEASED: Thread pool executor shut down: %s", - session_id_hex, - flush=True, - ) From 14433c4138fdb2527285c7160c065d117707d9fb Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Sun, 22 Jun 2025 16:48:51 +0530 Subject: [PATCH 67/86] remove debugging Signed-off-by: Sai Shree Pradhan --- .github/workflows/code-quality-checks.yml | 2 +- .../sql/telemetry/telemetry_client.py | 3 ++ src/databricks/sql/thrift_backend.py | 1 - tests/unit/test_client.py | 39 ------------------- 4 files changed, 4 insertions(+), 41 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index ba446f878..462d22369 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -112,7 +112,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run python -m pytest tests/unit + run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest strategy: diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index e3e71ba7f..64b43227c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -387,6 +387,7 @@ def _install_exception_hook(cls): def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback): """Handle unhandled exceptions by sending telemetry and flushing thread pool""" logger.debug("Handling unhandled exception: %s", exc_type.__name__) + clients_to_close = list(cls._clients.values()) for client in clients_to_close: client.close() @@ -403,6 +404,7 @@ def initialize_telemetry_client( host_url, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" + try: with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() @@ -450,6 +452,7 @@ def get_telemetry_client(session_id_hex): @staticmethod def close(session_id_hex): """Close and remove the telemetry client for a specific connection""" + with TelemetryClientFactory._lock: if ( telemetry_client := TelemetryClientFactory._clients.pop( diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 0cd5fb124..78683ac31 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -42,7 +42,6 @@ ) from databricks.sql.types import SSLOptions - logger = logging.getLogger(__name__) unsafe_logger = logging.getLogger("databricks.sql.unsafe") diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4610928d2..588b0d70e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -83,44 +83,6 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - # def setUp(self): - # """Set up connection tracking before each test""" - # self.tracked_connections = [] - - # # Store the original connect function - # self.original_connect = databricks.sql.connect - - # def patched_connect(*args, **kwargs): - # """Wrapper that tracks all created connections""" - # conn = self.original_connect(*args, **kwargs) - - # # Skip tracking for finalizer tests to allow garbage collection - # if not (hasattr(self, '_testMethodName') and 'finalizer' in self._testMethodName): - # self.tracked_connections.append(conn) - - # return conn - - # # Apply the patch to track connections - # self.connect_patcher = patch('databricks.sql.connect', patched_connect) - # self.connect_patcher.start() - - # def tearDown(self): - # """Clean up connections after each test""" - # # Close all tracked connections - # for conn in self.tracked_connections: - # try: - # if hasattr(conn, 'open') and conn.open: - # conn.close() - # except Exception as e: - # # Log the error but don't fail the test - # print(f"Warning: Error closing connection in tearDown: {e}") - - # # Stop the connect patcher - # self.connect_patcher.stop() - - # # Clear the tracked connections list - # self.tracked_connections.clear() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value @@ -790,7 +752,6 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) - connection.close() def test_cursor_close_handles_exception(self): """Test that Cursor.close() handles exceptions from close_command properly.""" From ef4ca1349959b3e40959e02397845ee028368fba Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 23 Jun 2025 12:36:05 +0530 Subject: [PATCH 68/86] added test for export_latency_log, made mock of thrift backend with retry policy Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 16 ++++--- .../sql/telemetry/telemetry_client.py | 8 +--- tests/unit/test_client.py | 22 +++++++-- tests/unit/test_fetches.py | 13 +++++- tests/unit/test_telemetry.py | 45 +++++++++++++++++++ 5 files changed, 86 insertions(+), 18 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f232ba427..5351cb031 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1185,7 +1185,6 @@ def columns( ) return self - @log_latency() def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences. @@ -1219,7 +1218,6 @@ def fetchone(self) -> Optional[Row]: session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency() def fetchmany(self, size: int) -> List[Row]: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a @@ -1245,7 +1243,6 @@ def fetchmany(self, size: int) -> List[Row]: session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency() def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: @@ -1256,7 +1253,6 @@ def fetchall_arrow(self) -> "pyarrow.Table": session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency() def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: @@ -1380,7 +1376,11 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.FORMAT_UNSPECIFIED def get_retry_count(self) -> int: - # return len(self.thrift_backend.retry_policy.history) + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) return 0 def get_statement_type(self, func_name: str) -> StatementType: @@ -1712,5 +1712,9 @@ def get_statement_type(self, func_name: str) -> StatementType: return StatementType.SQL def get_retry_count(self) -> int: - # return len(self.thrift_backend.retry_policy.history) + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) return 0 diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 64b43227c..7134f11a2 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -113,9 +113,7 @@ def export_failure_log(self, error_name, error_message): raise NotImplementedError("Subclasses must implement export_failure_log") @abstractmethod - def export_latency_log( - self, latency_ms, sql_execution_event, sql_statement_id=None - ): + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): raise NotImplementedError("Subclasses must implement export_latency_log") @abstractmethod @@ -310,9 +308,7 @@ def export_failure_log(self, error_name, error_message): except Exception as e: logger.debug("Failed to export failure log: %s", e) - def export_latency_log( - self, latency_ms, sql_execution_event, sql_statement_id=None - ): + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): logger.debug("Exporting latency log for connection %s", self._session_id_hex) try: telemetry_frontend_log = TelemetryFrontendLog( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 588b0d70e..a53e4d49c 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -39,6 +39,11 @@ def new(cls): cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + # Mock retry_policy with history attribute + mock_retry_policy = Mock() + mock_retry_policy.history = [] + cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy) + cls.apply_property_to_mock( MockTExecuteStatementResp, description=None, @@ -70,6 +75,15 @@ def apply_property_to_mock(self, mock_obj, **kwargs): prop = PropertyMock(**kwargs) setattr(type(mock_obj), key, prop) + @classmethod + def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator + """Create a simple thrift_backend mock with retry_policy for basic tests.""" + mock_thrift_backend = Mock() + mock_retry_policy = Mock() + mock_retry_policy.history = [] + mock_thrift_backend.retry_policy = mock_retry_policy + return mock_thrift_backend + class ClientTestSuite(unittest.TestCase): """ @@ -319,7 +333,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_sets[1].fetchall.assert_called_once_with() def test_closed_cursor_doesnt_allow_operations(self): - cursor = client.Cursor(Mock(), Mock()) + cursor = client.Cursor(Mock(), ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()) cursor.close() with self.assertRaises(Error) as e: @@ -399,7 +413,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.schemas(**req_args) @@ -422,7 +436,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.tables(**req_args) @@ -445,7 +459,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.columns(**req_args) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2cb..43af1d361 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -31,6 +31,15 @@ def make_arrow_queue(batch): _, table = FetchTests.make_arrow_table(batch) queue = ArrowQueue(table, len(batch)) return queue + + @classmethod + def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator + """Create a simple thrift_backend mock with retry_policy for basic tests.""" + mock_thrift_backend = Mock() + mock_retry_policy = Mock() + mock_retry_policy.history = [] + mock_thrift_backend.retry_policy = mock_retry_policy + return mock_thrift_backend @staticmethod def make_dummy_result_set_from_initial_results(initial_results): @@ -39,7 +48,7 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) rs = client.ResultSet( connection=Mock(), - thrift_backend=None, + thrift_backend=FetchTests.mock_thrift_backend_with_retry_policy(), execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, @@ -79,7 +88,7 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = FetchTests.mock_thrift_backend_with_retry_policy() mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 699480bbe..577fe3f5c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -96,6 +96,12 @@ def test_export_failure_log(self, noop_telemetry_client): error_name="TestError", error_message="Test error message" ) + def test_export_latency_log(self, noop_telemetry_client): + """Test that export_latency_log does nothing.""" + noop_telemetry_client.export_latency_log( + latency_ms=100, sql_execution_event="EXECUTE_STATEMENT", sql_statement_id="test-id" + ) + def test_close(self, noop_telemetry_client): """Test that close does nothing.""" noop_telemetry_client.close() @@ -181,6 +187,40 @@ def test_export_failure_log( client._export_event.assert_called_once_with(mock_frontend_log.return_value) + @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") + @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") + @patch("databricks.sql.telemetry.telemetry_client.time.time") + def test_export_latency_log( + self, + mock_time, + mock_uuid4, + mock_get_driver_config, + mock_frontend_log, + telemetry_client_setup + ): + """Test exporting latency telemetry log.""" + mock_time.return_value = 3000 + mock_uuid4.return_value = "test-latency-uuid" + mock_get_driver_config.return_value = "test-driver-config" + mock_frontend_log.return_value = MagicMock() + + client = telemetry_client_setup["client"] + client._export_event = MagicMock() + + client._driver_connection_params = "test-connection-params" + client._user_agent = "test-user-agent" + + latency_ms = 150 + sql_execution_event = "test-execution-event" + sql_statement_id = "test-statement-id" + + client.export_latency_log(latency_ms, sql_execution_event, sql_statement_id) + + mock_frontend_log.assert_called_once() + + client._export_event.assert_called_once_with(mock_frontend_log.return_value) + def test_export_event(self, telemetry_client_setup): """Test exporting an event.""" client = telemetry_client_setup["client"] @@ -311,6 +351,11 @@ def test_telemetry_client_exception_handling(self, telemetry_client_setup): # Should not raise exception client.export_failure_log("TestError", "Test error message") + # Test export_latency_log with exception + with patch.object(client, '_export_event', side_effect=Exception("Test error")): + # Should not raise exception + client.export_latency_log(100, "EXECUTE_STATEMENT", "test-statement-id") + # Test _send_telemetry with exception with patch.object(client._executor, 'submit', side_effect=Exception("Test error")): # Should not raise exception From b5bf1656056c55d5d814f5f12f9400bc46ed7441 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 23 Jun 2025 16:41:00 +0530 Subject: [PATCH 69/86] added multi threaded tests Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 8 +- tests/unit/test_telemetry.py | 406 +++++++++++++++++- 2 files changed, 406 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 667c5aaac..5f4b6f079 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -158,6 +158,7 @@ class TelemetryClient(BaseTelemetryClient): # Telemetry endpoint paths TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" + DEFAULT_BATCH_SIZE = 10 def __init__( self, @@ -169,7 +170,7 @@ def __init__( ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled - self._batch_size = 10 # TODO: Decide on batch size + self._batch_size = self.DEFAULT_BATCH_SIZE # TODO: Decide on batch size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None @@ -466,6 +467,9 @@ def close(session_id_hex): logger.debug( "No more telemetry clients, shutting down thread pool executor" ) - TelemetryClientFactory._executor.shutdown(wait=True) + try: + TelemetryClientFactory._executor.shutdown(wait=True) + except Exception as e: + logger.debug("Failed to shutdown thread pool executor: %s", e) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 577fe3f5c..cd05b1565 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,6 +2,9 @@ import pytest import requests from unittest.mock import patch, MagicMock, call +import threading +import time +import random from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -226,17 +229,16 @@ def test_export_event(self, telemetry_client_setup): client = telemetry_client_setup["client"] client._flush = MagicMock() - for i in range(5): + for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1): client._export_event(f"event-{i}") client._flush.assert_not_called() - assert len(client._events_batch) == 5 + assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1 - for i in range(5, 10): - client._export_event(f"event-{i}") + # Add one more event to reach batch size (this will trigger flush) + client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}") client._flush.assert_called_once() - assert len(client._events_batch) == 10 @patch("requests.post") def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): @@ -543,4 +545,396 @@ def test_global_exception_hook(self, mock_handle_exception, telemetry_system_res test_exception = ValueError("Test exception") TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) + + def test_initialize_telemetry_client_exception_handling(self, telemetry_system_reset): + """Test that exceptions in initialize_telemetry_client don't cause connector to fail.""" + session_id_hex = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Test exception during TelemetryClient creation + with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', side_effect=Exception("TelemetryClient creation failed")): + # Should not raise exception, should fallback to NoopTelemetryClient + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_get_telemetry_client_exception_handling(self, telemetry_system_reset): + """Test that exceptions in get_telemetry_client don't cause connector to fail.""" + session_id_hex = "test-uuid" + + # Test exception during client lookup by mocking the clients dict + mock_clients = MagicMock() + mock_clients.__contains__.side_effect = Exception("Client lookup failed") + + with patch.object(TelemetryClientFactory, '_clients', mock_clients): + # Should not raise exception, should return NoopTelemetryClient + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_get_telemetry_client_dict_access_exception(self, telemetry_system_reset): + """Test that exceptions during dictionary access don't cause connector to fail.""" + session_id_hex = "test-uuid" + + # Test exception during dictionary access + mock_clients = MagicMock() + mock_clients.__contains__.side_effect = Exception("Dictionary access failed") + TelemetryClientFactory._clients = mock_clients + + # Should not raise exception, should return NoopTelemetryClient + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_close_telemetry_client_shutdown_executor_exception(self, telemetry_system_reset): + """Test that exceptions during executor shutdown don't cause connector to fail.""" + session_id_hex = "test-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Mock executor to raise exception during shutdown + mock_executor = MagicMock() + mock_executor.shutdown.side_effect = Exception("Executor shutdown failed") + TelemetryClientFactory._executor = mock_executor + + # Should not raise exception (executor shutdown is wrapped in try-catch) + TelemetryClientFactory.close(session_id_hex) + + # Verify executor shutdown was attempted + mock_executor.shutdown.assert_called_once_with(wait=True) + + +class TestTelemetryRaceConditions: + """Tests for race conditions in multithreaded scenarios.""" + + @pytest.fixture + def race_condition_setup(self): + """Setup for race condition tests.""" + # Reset telemetry system + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + yield + + # Cleanup + TelemetryClientFactory._clients.clear() + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + def test_telemetry_client_concurrent_export_events(self, race_condition_setup): + """Test race conditions in TelemetryClient._export_event with concurrent access.""" + session_id_hex = "test-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _flush to avoid actual network calls + client._flush = MagicMock() + + # Track events added by each thread + thread_events = {} + lock = threading.Lock() + + def add_events(thread_id): + """Add events from a specific thread.""" + events = [] + for i in range(10): + event = f"event-{thread_id}-{i}" + client._export_event(event) + events.append(event) + + with lock: + thread_events[thread_id] = events + + # Start multiple threads adding events concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=add_events, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all events were added (no data loss due to race conditions) + total_expected_events = sum(len(events) for events in thread_events.values()) + assert len(client._events_batch) == total_expected_events + + def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup): + """Test race conditions in TelemetryClient._flush with concurrent access.""" + session_id_hex = "test-flush-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _send_telemetry to avoid actual network calls + client._send_telemetry = MagicMock() + + # Add events to trigger flush + for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1): + client._export_event(f"event-{i}") + + # Track flush operations + flush_count = 0 + flush_lock = threading.Lock() + + def concurrent_flush(): + """Call flush concurrently.""" + nonlocal flush_count + client._flush() + with flush_lock: + flush_count += 1 + + # Start multiple threads calling flush concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_flush) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify flush was called the expected number of times + assert flush_count == 10 + + # Verify _send_telemetry was called at least once (some calls may have empty batches due to lock) + assert client._send_telemetry.call_count >= 1 + + # Verify that the total events processed is correct (no data loss) + # The first flush should have processed all events, subsequent flushes should have empty batches + total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list) + assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1 + + def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access.""" + session_id_hex = "test-factory-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Track initialization attempts + init_results = [] + init_lock = threading.Lock() + + def concurrent_initialize(thread_id): + """Initialize telemetry client concurrently.""" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with init_lock: + init_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__ + }) + + # Start multiple threads initializing concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_initialize, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + assert len(init_results) == 10 + + # Verify only one client was created (no duplicate clients due to race conditions) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, TelemetryClient) + + # Verify the client is the same for all threads (singleton behavior) + client_ids = set() + for result in init_results: + client_ids.add(id(TelemetryClientFactory.get_telemetry_client(session_id_hex))) + + assert len(client_ids) == 1 + + def test_telemetry_client_factory_concurrent_get_client(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.get_telemetry_client with concurrent access.""" + session_id_hex = "test-get-client-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Track get_client attempts + get_results = [] + get_lock = threading.Lock() + + def concurrent_get_client(thread_id): + """Get telemetry client concurrently.""" + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with get_lock: + get_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__, + 'client_id': id(client) + }) + + # Start multiple threads getting client concurrently + threads = [] + for i in range(20): + thread = threading.Thread(target=concurrent_get_client, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all get_client calls succeeded + assert len(get_results) == 20 + + # Verify all threads got the same client instance (no race conditions) + client_ids = set(result['client_id'] for result in get_results) + assert len(client_ids) == 1 # Only one client instance returned + + def test_telemetry_client_factory_concurrent_close(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.close with concurrent access.""" + session_id_hex = "test-close-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + def concurrent_close(thread_id): + """Close telemetry client concurrently.""" + + TelemetryClientFactory.close(session_id_hex) + + # Start multiple threads closing concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=concurrent_close, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify client is no longer available after close + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def test_telemetry_client_factory_mixed_concurrent_operations(self, race_condition_setup): + """Test race conditions with mixed concurrent operations on TelemetryClientFactory.""" + session_id_hex = "test-mixed-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Track operation results + operation_results = [] + operation_lock = threading.Lock() + + def mixed_operations(thread_id): + """Perform mixed operations concurrently.""" + + # Randomly choose an operation + operation = random.choice(['init', 'get', 'close']) + + if operation == 'init': + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'init', + 'client_type': type(client).__name__ + }) + + elif operation == 'get': + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'get', + 'client_type': type(client).__name__ + }) + + elif operation == 'close': + TelemetryClientFactory.close(session_id_hex) + + with operation_lock: + operation_results.append({ + 'thread_id': thread_id, + 'operation': 'close' + }) + + # Start multiple threads performing mixed operations + threads = [] + for i in range(15): + thread = threading.Thread(target=mixed_operations, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + assert len(operation_results) == 15 From 307a8cc57aacc8529d781f63226dc9d3b5fe249e Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 23 Jun 2025 16:43:05 +0530 Subject: [PATCH 70/86] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5f4b6f079..846db448d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -158,7 +158,7 @@ class TelemetryClient(BaseTelemetryClient): # Telemetry endpoint paths TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" - DEFAULT_BATCH_SIZE = 10 + DEFAULT_BATCH_SIZE = 10 def __init__( self, From 0fd46d41f602f63846812ddf5a59389f4a594615 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 10:21:19 +0530 Subject: [PATCH 71/86] added TelemetryExtractor, removed multithreaded tests Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 67 --- .../sql/telemetry/latency_logger.py | 110 ++++- tests/unit/test_telemetry.py | 394 +----------------- 3 files changed, 105 insertions(+), 466 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5351cb031..50b709e8a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -31,8 +31,6 @@ transform_paramstyle, ColumnTable, ColumnQueue, - ArrowQueue, - CloudFetchQueue, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -64,7 +62,6 @@ HostDetails, ) from databricks.sql.telemetry.latency_logger import log_latency -from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType logger = logging.getLogger(__name__) @@ -1354,39 +1351,6 @@ def setoutputsize(self, size, column=None): """Does nothing by default""" pass - def get_statement_id(self) -> Optional[str]: - return self.query_id - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.connection.lz4_compression - - def get_execution_result(self) -> ExecutionResultFormat: - if self.active_result_set is None: - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - if isinstance(self.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.active_result_set.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.active_result_set.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) - return 0 - - def get_statement_type(self, func_name: str) -> StatementType: - # TODO: Implement this - return StatementType.SQL - class ResultSet: def __init__( @@ -1687,34 +1651,3 @@ def map_col_type(type_): for column in table_schema_message.columns ] - def get_statement_id(self) -> Optional[str]: - if self.command_id: - return str(UUID(bytes=self.command_id.operationId.guid)) - return None - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.lz4_compressed - - def get_execution_result(self) -> ExecutionResultFormat: - if isinstance(self.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_statement_type(self, func_name: str) -> StatementType: - # TODO: Implement this - return StatementType.SQL - - def get_retry_count(self) -> int: - if ( - hasattr(self.thrift_backend, "retry_policy") - and self.thrift_backend.retry_policy - ): - return len(self.thrift_backend.retry_policy.history) - return 0 diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 75460717d..0c0f883ce 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,9 +1,106 @@ import time import functools +from typing import Optional from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( SqlExecutionEvent, ) +from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType +from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue +from uuid import UUID + + +class TelemetryExtractor: + def __init__(self, obj): + self._obj = obj + + def __getattr__(self, name): + return getattr(self._obj, name) + + def get_session_id_hex(self): pass + def get_statement_id(self): pass + def get_statement_type(self): pass + def get_is_compressed(self): pass + def get_execution_result(self): pass + def get_retry_count(self): pass + + +class CursorExtractor(TelemetryExtractor): + def get_statement_id(self) -> Optional[str]: + return self.query_id + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.connection.lz4_compression + + def get_execution_result(self) -> ExecutionResultFormat: + if self.active_result_set is None: + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + if isinstance(self.active_result_set.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.active_result_set.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.active_result_set.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_retry_count(self) -> int: + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) + return 0 + + def get_statement_type(self: str) -> StatementType: + # TODO: Implement this + return StatementType.SQL + + +class ResultSetExtractor(TelemetryExtractor): + def get_statement_id(self) -> Optional[str]: + if self.command_id: + return str(UUID(bytes=self.command_id.operationId.guid)) + return None + + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() + + def get_is_compressed(self) -> bool: + return self.lz4_compressed + + def get_execution_result(self) -> ExecutionResultFormat: + if isinstance(self.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_statement_type(self: str) -> StatementType: + # TODO: Implement this + return StatementType.SQL + + def get_retry_count(self) -> int: + if ( + hasattr(self.thrift_backend, "retry_policy") + and self.thrift_backend.retry_policy + ): + return len(self.thrift_backend.retry_policy.history) + return 0 + + +def get_extractor(obj): + if obj.__class__.__name__ == 'Cursor': + return CursorExtractor(obj) + elif obj.__class__.__name__ == 'ResultSet': + return ResultSetExtractor(obj) + else: + return TelemetryExtractor(obj) def log_latency(): @@ -19,14 +116,15 @@ def wrapper(self, *args, **kwargs): end_time = time.perf_counter() duration_ms = int((end_time - start_time) * 1000) - session_id_hex = self.get_session_id_hex() - statement_id = self.get_statement_id() + extractor = get_extractor(self) + session_id_hex = extractor.get_session_id_hex() + statement_id = extractor.get_statement_id() sql_exec_event = SqlExecutionEvent( - statement_type=self.get_statement_type(func.__name__), - is_compressed=self.get_is_compressed(), - execution_result=self.get_execution_result(), - retry_count=self.get_retry_count(), + statement_type=extractor.get_statement_type(), + is_compressed=extractor.get_is_compressed(), + execution_result=extractor.get_execution_result(), + retry_count=extractor.get_retry_count(), ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index cd05b1565..8707f09de 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -545,396 +545,4 @@ def test_global_exception_hook(self, mock_handle_exception, telemetry_system_res test_exception = ValueError("Test exception") TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) - - def test_initialize_telemetry_client_exception_handling(self, telemetry_system_reset): - """Test that exceptions in initialize_telemetry_client don't cause connector to fail.""" - session_id_hex = "test-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Test exception during TelemetryClient creation - with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', side_effect=Exception("TelemetryClient creation failed")): - # Should not raise exception, should fallback to NoopTelemetryClient - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - - def test_get_telemetry_client_exception_handling(self, telemetry_system_reset): - """Test that exceptions in get_telemetry_client don't cause connector to fail.""" - session_id_hex = "test-uuid" - - # Test exception during client lookup by mocking the clients dict - mock_clients = MagicMock() - mock_clients.__contains__.side_effect = Exception("Client lookup failed") - - with patch.object(TelemetryClientFactory, '_clients', mock_clients): - # Should not raise exception, should return NoopTelemetryClient - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - - def test_get_telemetry_client_dict_access_exception(self, telemetry_system_reset): - """Test that exceptions during dictionary access don't cause connector to fail.""" - session_id_hex = "test-uuid" - - # Test exception during dictionary access - mock_clients = MagicMock() - mock_clients.__contains__.side_effect = Exception("Dictionary access failed") - TelemetryClientFactory._clients = mock_clients - - # Should not raise exception, should return NoopTelemetryClient - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - - def test_close_telemetry_client_shutdown_executor_exception(self, telemetry_system_reset): - """Test that exceptions during executor shutdown don't cause connector to fail.""" - session_id_hex = "test-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Initialize a client first - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - - # Mock executor to raise exception during shutdown - mock_executor = MagicMock() - mock_executor.shutdown.side_effect = Exception("Executor shutdown failed") - TelemetryClientFactory._executor = mock_executor - - # Should not raise exception (executor shutdown is wrapped in try-catch) - TelemetryClientFactory.close(session_id_hex) - - # Verify executor shutdown was attempted - mock_executor.shutdown.assert_called_once_with(wait=True) - - -class TestTelemetryRaceConditions: - """Tests for race conditions in multithreaded scenarios.""" - - @pytest.fixture - def race_condition_setup(self): - """Setup for race condition tests.""" - # Reset telemetry system - TelemetryClientFactory._clients.clear() - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False - - yield - - # Cleanup - TelemetryClientFactory._clients.clear() - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._initialized = False - - def test_telemetry_client_concurrent_export_events(self, race_condition_setup): - """Test race conditions in TelemetryClient._export_event with concurrent access.""" - session_id_hex = "test-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - executor = MagicMock() - - client = TelemetryClient( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=executor, - ) - - # Mock _flush to avoid actual network calls - client._flush = MagicMock() - - # Track events added by each thread - thread_events = {} - lock = threading.Lock() - - def add_events(thread_id): - """Add events from a specific thread.""" - events = [] - for i in range(10): - event = f"event-{thread_id}-{i}" - client._export_event(event) - events.append(event) - - with lock: - thread_events[thread_id] = events - - # Start multiple threads adding events concurrently - threads = [] - for i in range(5): - thread = threading.Thread(target=add_events, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify all events were added (no data loss due to race conditions) - total_expected_events = sum(len(events) for events in thread_events.values()) - assert len(client._events_batch) == total_expected_events - - def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup): - """Test race conditions in TelemetryClient._flush with concurrent access.""" - session_id_hex = "test-flush-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - executor = MagicMock() - - client = TelemetryClient( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=executor, - ) - - # Mock _send_telemetry to avoid actual network calls - client._send_telemetry = MagicMock() - - # Add events to trigger flush - for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1): - client._export_event(f"event-{i}") - - # Track flush operations - flush_count = 0 - flush_lock = threading.Lock() - - def concurrent_flush(): - """Call flush concurrently.""" - nonlocal flush_count - client._flush() - with flush_lock: - flush_count += 1 - - # Start multiple threads calling flush concurrently - threads = [] - for i in range(10): - thread = threading.Thread(target=concurrent_flush) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify flush was called the expected number of times - assert flush_count == 10 - - # Verify _send_telemetry was called at least once (some calls may have empty batches due to lock) - assert client._send_telemetry.call_count >= 1 - - # Verify that the total events processed is correct (no data loss) - # The first flush should have processed all events, subsequent flushes should have empty batches - total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list) - assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1 - - def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup): - """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access.""" - session_id_hex = "test-factory-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Track initialization attempts - init_results = [] - init_lock = threading.Lock() - - def concurrent_initialize(thread_id): - """Initialize telemetry client concurrently.""" - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - - with init_lock: - init_results.append({ - 'thread_id': thread_id, - 'client_type': type(client).__name__ - }) - - # Start multiple threads initializing concurrently - threads = [] - for i in range(10): - thread = threading.Thread(target=concurrent_initialize, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(init_results) == 10 - - # Verify only one client was created (no duplicate clients due to race conditions) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, TelemetryClient) - - # Verify the client is the same for all threads (singleton behavior) - client_ids = set() - for result in init_results: - client_ids.add(id(TelemetryClientFactory.get_telemetry_client(session_id_hex))) - - assert len(client_ids) == 1 - - def test_telemetry_client_factory_concurrent_get_client(self, race_condition_setup): - """Test race conditions in TelemetryClientFactory.get_telemetry_client with concurrent access.""" - session_id_hex = "test-get-client-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Initialize a client first - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - - # Track get_client attempts - get_results = [] - get_lock = threading.Lock() - - def concurrent_get_client(thread_id): - """Get telemetry client concurrently.""" - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - - with get_lock: - get_results.append({ - 'thread_id': thread_id, - 'client_type': type(client).__name__, - 'client_id': id(client) - }) - - # Start multiple threads getting client concurrently - threads = [] - for i in range(20): - thread = threading.Thread(target=concurrent_get_client, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify all get_client calls succeeded - assert len(get_results) == 20 - - # Verify all threads got the same client instance (no race conditions) - client_ids = set(result['client_id'] for result in get_results) - assert len(client_ids) == 1 # Only one client instance returned - - def test_telemetry_client_factory_concurrent_close(self, race_condition_setup): - """Test race conditions in TelemetryClientFactory.close with concurrent access.""" - session_id_hex = "test-close-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Initialize a client first - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - - def concurrent_close(thread_id): - """Close telemetry client concurrently.""" - - TelemetryClientFactory.close(session_id_hex) - - # Start multiple threads closing concurrently - threads = [] - for i in range(5): - thread = threading.Thread(target=concurrent_close, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify client is no longer available after close - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - - def test_telemetry_client_factory_mixed_concurrent_operations(self, race_condition_setup): - """Test race conditions with mixed concurrent operations on TelemetryClientFactory.""" - session_id_hex = "test-mixed-race-uuid" - auth_provider = MagicMock() - host_url = "test-host" - - # Track operation results - operation_results = [] - operation_lock = threading.Lock() - - def mixed_operations(thread_id): - """Perform mixed operations concurrently.""" - - # Randomly choose an operation - operation = random.choice(['init', 'get', 'close']) - - if operation == 'init': - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - - with operation_lock: - operation_results.append({ - 'thread_id': thread_id, - 'operation': 'init', - 'client_type': type(client).__name__ - }) - - elif operation == 'get': - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - - with operation_lock: - operation_results.append({ - 'thread_id': thread_id, - 'operation': 'get', - 'client_type': type(client).__name__ - }) - - elif operation == 'close': - TelemetryClientFactory.close(session_id_hex) - - with operation_lock: - operation_results.append({ - 'thread_id': thread_id, - 'operation': 'close' - }) - - # Start multiple threads performing mixed operations - threads = [] - for i in range(15): - thread = threading.Thread(target=mixed_operations, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(operation_results) == 15 + mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file From f6f50b2791339cd239bbbb8c2656122b491e8307 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 10:28:04 +0530 Subject: [PATCH 72/86] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 1 - .../sql/telemetry/latency_logger.py | 35 ++++++++++++------- .../sql/telemetry/telemetry_client.py | 7 ++-- tests/unit/test_telemetry.py | 9 ++--- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 50b709e8a..289b25664 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1650,4 +1650,3 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] - diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 0c0f883ce..c6476b362 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -13,16 +13,27 @@ class TelemetryExtractor: def __init__(self, obj): self._obj = obj - + def __getattr__(self, name): return getattr(self._obj, name) - - def get_session_id_hex(self): pass - def get_statement_id(self): pass - def get_statement_type(self): pass - def get_is_compressed(self): pass - def get_execution_result(self): pass - def get_retry_count(self): pass + + def get_session_id_hex(self): + pass + + def get_statement_id(self): + pass + + def get_statement_type(self): + pass + + def get_is_compressed(self): + pass + + def get_execution_result(self): + pass + + def get_retry_count(self): + pass class CursorExtractor(TelemetryExtractor): @@ -58,8 +69,8 @@ def get_retry_count(self) -> int: def get_statement_type(self: str) -> StatementType: # TODO: Implement this return StatementType.SQL - - + + class ResultSetExtractor(TelemetryExtractor): def get_statement_id(self) -> Optional[str]: if self.command_id: @@ -95,9 +106,9 @@ def get_retry_count(self) -> int: def get_extractor(obj): - if obj.__class__.__name__ == 'Cursor': + if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) - elif obj.__class__.__name__ == 'ResultSet': + elif obj.__class__.__name__ == "ResultSet": return ResultSetExtractor(obj) else: return TelemetryExtractor(obj) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 846db448d..f83f0598c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -140,9 +140,7 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): def export_failure_log(self, error_name, error_message): pass - def export_latency_log( - self, latency_ms, sql_execution_event, sql_statement_id=None - ): + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): pass def close(self): @@ -158,7 +156,6 @@ class TelemetryClient(BaseTelemetryClient): # Telemetry endpoint paths TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" - DEFAULT_BATCH_SIZE = 10 def __init__( self, @@ -170,7 +167,7 @@ def __init__( ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled - self._batch_size = self.DEFAULT_BATCH_SIZE # TODO: Decide on batch size + self._batch_size = 10 # TODO: Decide on batch size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 8707f09de..2c6b24e07 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -229,16 +229,17 @@ def test_export_event(self, telemetry_client_setup): client = telemetry_client_setup["client"] client._flush = MagicMock() - for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1): + for i in range(5): client._export_event(f"event-{i}") client._flush.assert_not_called() - assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1 + assert len(client._events_batch) == 5 - # Add one more event to reach batch size (this will trigger flush) - client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}") + for i in range(5): + client._export_event(f"event-{i}") client._flush.assert_called_once() + assert len(client._events_batch) == 10 @patch("requests.post") def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): From 1163ebef27c149e881c70fcdbfab9724a47a52a6 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 10:32:22 +0530 Subject: [PATCH 73/86] fixes in test Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 1 - tests/unit/test_telemetry.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f83f0598c..421148aa1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -399,7 +399,6 @@ def initialize_telemetry_client( host_url, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" - try: with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 2c6b24e07..36461e9e9 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -235,7 +235,7 @@ def test_export_event(self, telemetry_client_setup): client._flush.assert_not_called() assert len(client._events_batch) == 5 - for i in range(5): + for i in range(5, 10): client._export_event(f"event-{i}") client._flush.assert_called_once() From 4b6ace030c30f8b4adea90eee7bb10bff0837943 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 10:37:17 +0530 Subject: [PATCH 74/86] fix in telemetry extractor Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/latency_logger.py | 4 ++-- src/databricks/sql/telemetry/telemetry_client.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index c6476b362..a2871f6fe 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -66,7 +66,7 @@ def get_retry_count(self) -> int: return len(self.thrift_backend.retry_policy.history) return 0 - def get_statement_type(self: str) -> StatementType: + def get_statement_type(self) -> StatementType: # TODO: Implement this return StatementType.SQL @@ -92,7 +92,7 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.INLINE_ARROW return ExecutionResultFormat.FORMAT_UNSPECIFIED - def get_statement_type(self: str) -> StatementType: + def get_statement_type(self) -> StatementType: # TODO: Implement this return StatementType.SQL diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 421148aa1..cbdc9f9ff 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -400,6 +400,7 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: + with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() From 4d5614193163e5a1d33da94f7da3de455a991590 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 30 Jun 2025 14:46:23 +0530 Subject: [PATCH 75/86] added doc strings to latency_logger, abstracted export_telemetry_log Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/latency_logger.py | 86 ++++++++++++++++- .../sql/telemetry/telemetry_client.py | 93 ++++++------------- 2 files changed, 115 insertions(+), 64 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index a2871f6fe..3a2e61a84 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -11,10 +11,32 @@ class TelemetryExtractor: + """ + Base class for extracting telemetry information from various object types. + + This class serves as a proxy that delegates attribute access to the wrapped object + while providing a common interface for extracting telemetry-related data. + """ + def __init__(self, obj): + """ + Initialize the extractor with an object to wrap. + + Args: + obj: The object to extract telemetry information from. + """ self._obj = obj def __getattr__(self, name): + """ + Delegate attribute access to the wrapped object. + + Args: + name (str): The name of the attribute to access. + + Returns: + The attribute value from the wrapped object. + """ return getattr(self._obj, name) def get_session_id_hex(self): @@ -37,6 +59,13 @@ def get_retry_count(self): class CursorExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for Cursor objects. + + Extracts telemetry information from database cursor objects, including + statement IDs, session information, compression settings, and result formats. + """ + def get_statement_id(self) -> Optional[str]: return self.query_id @@ -72,6 +101,13 @@ def get_statement_type(self) -> StatementType: class ResultSetExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for ResultSet objects. + + Extracts telemetry information from database result set objects, including + operation IDs, session information, compression settings, and result formats. + """ + def get_statement_id(self) -> Optional[str]: if self.command_id: return str(UUID(bytes=self.command_id.operationId.guid)) @@ -106,15 +142,63 @@ def get_retry_count(self) -> int: def get_extractor(obj): + """ + Factory function to create the appropriate telemetry extractor for an object. + + Determines the object type and returns the corresponding specialized extractor + that can extract telemetry information from that object type. + + Args: + obj: The object to create an extractor for. Can be a Cursor, ResultSet, + or any other object. + + Returns: + TelemetryExtractor: A specialized extractor instance: + - CursorExtractor for Cursor objects + - ResultSetExtractor for ResultSet objects + - TelemetryExtractor (base) for all other objects + """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) elif obj.__class__.__name__ == "ResultSet": return ResultSetExtractor(obj) else: - return TelemetryExtractor(obj) + raise NotImplementedError(f"No extractor found for {obj.__class__.__name__}") def log_latency(): + """ + Decorator for logging execution latency and telemetry information. + + This decorator measures the execution time of a method and sends telemetry + data about the operation, including latency, statement information, and + execution context. + + The decorator automatically: + - Measures execution time using high-precision performance counters + - Extracts telemetry information from the method's object (self) + - Creates a SqlExecutionEvent with execution details + - Sends the telemetry data asynchronously via TelemetryClient + + Usage: + @log_latency() + def execute(self, query): + # Method implementation + pass + + @log_latency() + def fetchall(self): + # Method implementation + pass + + Returns: + function: A decorator that wraps methods to add latency logging. + + Note: + The wrapped method's object (self) must be compatible with the + telemetry extractor system (e.g., Cursor or ResultSet objects). + """ + def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index cbdc9f9ff..936a07683 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -248,14 +248,24 @@ def _telemetry_request_callback(self, future): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) - def export_initial_telemetry_log(self, driver_connection_params, user_agent): - logger.debug( - "Exporting initial telemetry log for connection %s", self._session_id_hex - ) + def _export_telemetry_log(self, **telemetry_event_kwargs): + """ + Common helper method for exporting telemetry logs. + + Args: + **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor + """ + logger.debug("Exporting telemetry log for connection %s", self._session_id_hex) try: - self._driver_connection_params = driver_connection_params - self._user_agent = user_agent + # Set common fields for all telemetry events + event_kwargs = { + "session_id": self._session_id_hex, + "system_configuration": TelemetryHelper.get_driver_system_configuration(), + "driver_connection_params": self._driver_connection_params, + } + # Add any additional fields passed in + event_kwargs.update(telemetry_event_kwargs) telemetry_frontend_log = TelemetryFrontendLog( frontend_log_event_id=str(uuid.uuid4()), @@ -265,72 +275,29 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): user_agent=self._user_agent, ) ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._session_id_hex, - system_configuration=TelemetryHelper.get_driver_system_configuration(), - driver_connection_params=self._driver_connection_params, - ) - ), + entry=FrontendLogEntry(sql_driver_log=TelemetryEvent(**event_kwargs)), ) self._export_event(telemetry_frontend_log) except Exception as e: - logger.debug("Failed to export initial telemetry log: %s", e) + logger.debug("Failed to export telemetry log: %s", e) + + def export_initial_telemetry_log(self, driver_connection_params, user_agent): + self._driver_connection_params = driver_connection_params + self._user_agent = user_agent + self._export_telemetry_log() def export_failure_log(self, error_name, error_message): - logger.debug("Exporting failure log for connection %s", self._session_id_hex) - try: - error_info = DriverErrorInfo( - error_name=error_name, stack_trace=error_message - ) - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._session_id_hex, - system_configuration=TelemetryHelper.get_driver_system_configuration(), - driver_connection_params=self._driver_connection_params, - error_info=error_info, - ) - ), - ) - self._export_event(telemetry_frontend_log) - except Exception as e: - logger.debug("Failed to export failure log: %s", e) + error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) + self._export_telemetry_log(error_info=error_info) def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): - logger.debug("Exporting latency log for connection %s", self._session_id_hex) - try: - telemetry_frontend_log = TelemetryFrontendLog( - frontend_log_event_id=str(uuid.uuid4()), - context=FrontendLogContext( - client_context=TelemetryClientContext( - timestamp_millis=int(time.time() * 1000), - user_agent=self._user_agent, - ) - ), - entry=FrontendLogEntry( - sql_driver_log=TelemetryEvent( - session_id=self._session_id_hex, - system_configuration=TelemetryHelper.get_driver_system_configuration(), - driver_connection_params=self._driver_connection_params, - sql_statement_id=sql_statement_id, - sql_operation=sql_execution_event, - operation_latency_ms=latency_ms, - ) - ), - ) - self._export_event(telemetry_frontend_log) - except Exception as e: - logger.debug("Failed to export latency log: %s", e) + self._export_telemetry_log( + sql_statement_id=sql_statement_id, + sql_operation=sql_execution_event, + operation_latency_ms=latency_ms, + ) def close(self): """Flush remaining events before closing""" From 27295c2c3c2465829272b3157c2fd898192475b8 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 30 Jun 2025 22:22:40 +0530 Subject: [PATCH 76/86] statement type, unit test fix Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 17 +++++---- .../sql/telemetry/latency_logger.py | 25 ++++--------- tests/unit/test_client.py | 35 +++++++++---------- 3 files changed, 34 insertions(+), 43 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 8c5b660d5..89f6b7456 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -62,6 +62,7 @@ HostDetails, ) from databricks.sql.telemetry.latency_logger import log_latency +from databricks.sql.telemetry.models.enums import StatementType logger = logging.getLogger(__name__) @@ -827,7 +828,7 @@ def _handle_staging_remove( session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency() + @log_latency(StatementType.SQL) def execute( self, operation: str, @@ -918,7 +919,7 @@ def execute( return self - @log_latency() + @log_latency(StatementType.SQL) def execute_async( self, operation: str, @@ -1044,7 +1045,7 @@ def executemany(self, operation, seq_of_parameters): self.execute(operation, parameters) return self - @log_latency() + @log_latency(StatementType.METADATA) def catalogs(self) -> "Cursor": """ Get all available catalogs. @@ -1068,7 +1069,7 @@ def catalogs(self) -> "Cursor": ) return self - @log_latency() + @log_latency(StatementType.METADATA) def schemas( self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None ) -> "Cursor": @@ -1097,7 +1098,7 @@ def schemas( ) return self - @log_latency() + @log_latency(StatementType.METADATA) def tables( self, catalog_name: Optional[str] = None, @@ -1133,7 +1134,7 @@ def tables( ) return self - @log_latency() + @log_latency(StatementType.METADATA) def columns( self, catalog_name: Optional[str] = None, @@ -1444,6 +1445,7 @@ def _convert_arrow_table(self, table): def rownumber(self): return self._next_row_index + @log_latency() def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1486,6 +1488,7 @@ def merge_columnar(self, result1, result2): ] return ColumnTable(merged_result, result1.column_names) + @log_latency() def fetchmany_columnar(self, size: int): """ Fetch the next set of rows of a query result, returning a Columnar Table. @@ -1511,6 +1514,7 @@ def fetchmany_columnar(self, size: int): return results + @log_latency() def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() @@ -1537,6 +1541,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": return pyarrow.Table.from_pydict(data) return results + @log_latency() def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" results = self.results.remaining_rows() diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 3a2e61a84..06e6e598d 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -45,9 +45,6 @@ def get_session_id_hex(self): def get_statement_id(self): pass - def get_statement_type(self): - pass - def get_is_compressed(self): pass @@ -95,10 +92,6 @@ def get_retry_count(self) -> int: return len(self.thrift_backend.retry_policy.history) return 0 - def get_statement_type(self) -> StatementType: - # TODO: Implement this - return StatementType.SQL - class ResultSetExtractor(TelemetryExtractor): """ @@ -128,10 +121,6 @@ def get_execution_result(self) -> ExecutionResultFormat: return ExecutionResultFormat.INLINE_ARROW return ExecutionResultFormat.FORMAT_UNSPECIFIED - def get_statement_type(self) -> StatementType: - # TODO: Implement this - return StatementType.SQL - def get_retry_count(self) -> int: if ( hasattr(self.thrift_backend, "retry_policy") @@ -166,7 +155,7 @@ def get_extractor(obj): raise NotImplementedError(f"No extractor found for {obj.__class__.__name__}") -def log_latency(): +def log_latency(statement_type: StatementType = StatementType.NONE): """ Decorator for logging execution latency and telemetry information. @@ -180,17 +169,15 @@ def log_latency(): - Creates a SqlExecutionEvent with execution details - Sends the telemetry data asynchronously via TelemetryClient + Args: + statement_type (StatementType): The type of SQL statement being executed. + Usage: - @log_latency() + @log_latency(StatementType.SQL) def execute(self, query): # Method implementation pass - @log_latency() - def fetchall(self): - # Method implementation - pass - Returns: function: A decorator that wraps methods to add latency logging. @@ -216,7 +203,7 @@ def wrapper(self, *args, **kwargs): statement_id = extractor.get_statement_id() sql_exec_event = SqlExecutionEvent( - statement_type=extractor.get_statement_type(), + statement_type=statement_type, is_compressed=extractor.get_is_compressed(), execution_result=extractor.get_execution_result(), retry_count=extractor.get_retry_count(), diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 12a1f9f06..d418c2bac 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -8,6 +8,19 @@ from datetime import datetime, date from uuid import UUID +def noop_log_latency_decorator(*args, **kwargs): + """ + This is a no-op decorator. It is used to patch the log_latency decorator + during tests, so that the tests for client logic are not affected by the + telemetry logging logic. It accepts any arguments and returns a decorator + that returns the original function unmodified. + """ + def decorator(func): + return func + return decorator + +patch('databricks.sql.telemetry.latency_logger.log_latency', new=noop_log_latency_decorator).start() + from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TExecuteStatementResp, @@ -38,11 +51,6 @@ def new(cls): cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) - # Mock retry_policy with history attribute - mock_retry_policy = Mock() - mock_retry_policy.history = [] - cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy) - cls.apply_property_to_mock( MockTExecuteStatementResp, description=None, @@ -74,15 +82,6 @@ def apply_property_to_mock(self, mock_obj, **kwargs): prop = PropertyMock(**kwargs) setattr(type(mock_obj), key, prop) - @classmethod - def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator - """Create a simple thrift_backend mock with retry_policy for basic tests.""" - mock_thrift_backend = Mock() - mock_retry_policy = Mock() - mock_retry_policy.history = [] - mock_thrift_backend.retry_policy = mock_retry_policy - return mock_thrift_backend - class ClientTestSuite(unittest.TestCase): """ @@ -332,7 +331,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_sets[1].fetchall.assert_called_once_with() def test_closed_cursor_doesnt_allow_operations(self): - cursor = client.Cursor(Mock(), ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()) + cursor = client.Cursor(Mock(), Mock()) cursor.close() with self.assertRaises(Error) as e: @@ -394,7 +393,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.schemas(**req_args) @@ -417,7 +416,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.tables(**req_args) @@ -440,7 +439,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.columns(**req_args) From b558bc8ebab56e1256f3cfd14e951888cd112729 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Mon, 30 Jun 2025 22:26:40 +0530 Subject: [PATCH 77/86] unit test fix Signed-off-by: Sai Shree Pradhan --- tests/unit/test_fetches.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 43af1d361..098764fb9 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -1,12 +1,25 @@ import unittest import pytest -from unittest.mock import Mock +from unittest.mock import Mock, patch try: import pyarrow as pa except ImportError: pa = None +def noop_log_latency_decorator(*args, **kwargs): + """ + This is a no-op decorator. It is used to patch the log_latency decorator + during tests, so that the tests for client logic are not affected by the + telemetry logging logic. It accepts any arguments and returns a decorator + that returns the original function unmodified. + """ + def decorator(func): + return func + return decorator + +patch('databricks.sql.telemetry.latency_logger.log_latency', new=noop_log_latency_decorator).start() + import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue @@ -31,15 +44,6 @@ def make_arrow_queue(batch): _, table = FetchTests.make_arrow_table(batch) queue = ArrowQueue(table, len(batch)) return queue - - @classmethod - def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator - """Create a simple thrift_backend mock with retry_policy for basic tests.""" - mock_thrift_backend = Mock() - mock_retry_policy = Mock() - mock_retry_policy.history = [] - mock_thrift_backend.retry_policy = mock_retry_policy - return mock_thrift_backend @staticmethod def make_dummy_result_set_from_initial_results(initial_results): @@ -48,7 +52,7 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) rs = client.ResultSet( connection=Mock(), - thrift_backend=FetchTests.mock_thrift_backend_with_retry_policy(), + thrift_backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, @@ -88,7 +92,7 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = FetchTests.mock_thrift_backend_with_retry_policy() + mock_thrift_backend = Mock() mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 From 01853bc6a56baa6c383bf2b7726a878bfbf74d15 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 1 Jul 2025 09:47:30 +0530 Subject: [PATCH 78/86] statement type changes Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 89f6b7456..dbf4fa0a2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -746,7 +746,7 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency() + @log_latency(StatementType.SQL) def _handle_staging_put( self, presigned_url: str, local_file: str, headers: Optional[dict] = None ): @@ -786,7 +786,7 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) - @log_latency() + @log_latency(StatementType.SQL) def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None ): @@ -814,7 +814,7 @@ def _handle_staging_get( with open(local_file, "wb") as fp: fp.write(r.content) - @log_latency() + @log_latency(StatementType.SQL) def _handle_staging_remove( self, presigned_url: str, headers: Optional[dict] = None ): @@ -828,7 +828,7 @@ def _handle_staging_remove( session_id_hex=self.connection.get_session_id_hex(), ) - @log_latency(StatementType.SQL) + @log_latency(StatementType.QUERY) def execute( self, operation: str, @@ -919,7 +919,7 @@ def execute( return self - @log_latency(StatementType.SQL) + @log_latency(StatementType.QUERY) def execute_async( self, operation: str, From 45f74d0bd6301801b63cac8b6b044d1017801da7 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 1 Jul 2025 09:50:01 +0530 Subject: [PATCH 79/86] test_fetches fix Signed-off-by: Sai Shree Pradhan --- tests/unit/test_fetches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 098764fb9..3679df85f 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -10,7 +10,7 @@ def noop_log_latency_decorator(*args, **kwargs): """ This is a no-op decorator. It is used to patch the log_latency decorator - during tests, so that the tests for client logic are not affected by the + during tests, so that the tests for fetches logic are not affected by the telemetry logging logic. It accepts any arguments and returns a decorator that returns the original function unmodified. """ From 149d4a8b4995ca97eb71903f2f2677dc9a23d12d Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 1 Jul 2025 10:35:52 +0530 Subject: [PATCH 80/86] added mocks to resolve the errors caused by log_latency decorator in tests Signed-off-by: Sai Shree Pradhan --- tests/unit/test_client.py | 34 +++++++++++++++------------------- tests/unit/test_fetches.py | 23 +++++++++-------------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d418c2bac..c9e6a1d90 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -8,19 +8,6 @@ from datetime import datetime, date from uuid import UUID -def noop_log_latency_decorator(*args, **kwargs): - """ - This is a no-op decorator. It is used to patch the log_latency decorator - during tests, so that the tests for client logic are not affected by the - telemetry logging logic. It accepts any arguments and returns a decorator - that returns the original function unmodified. - """ - def decorator(func): - return func - return decorator - -patch('databricks.sql.telemetry.latency_logger.log_latency', new=noop_log_latency_decorator).start() - from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TExecuteStatementResp, @@ -51,6 +38,10 @@ def new(cls): cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) + mock_retry_policy = Mock() + mock_retry_policy.history = [] + cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy) + cls.apply_property_to_mock( MockTExecuteStatementResp, description=None, @@ -331,7 +322,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_sets[1].fetchall.assert_called_once_with() def test_closed_cursor_doesnt_allow_operations(self): - cursor = client.Cursor(Mock(), Mock()) + cursor = client.Cursor(Mock(), ThriftBackendMockFactory.new()) cursor.close() with self.assertRaises(Error) as e: @@ -343,14 +334,19 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - result_set = client.ResultSet(Mock(), Mock(), Mock()) + mock_connection = Mock() + mock_connection.get_session_id_hex.return_value = "test_session" + mock_execute_response = Mock() + mock_execute_response.command_handle = None + + result_set = client.ResultSet(mock_connection, mock_execute_response, ThriftBackendMockFactory.new()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): mock_close = Mock() - with client.Cursor(Mock(), Mock()) as cursor: + with client.Cursor(Mock(), ThriftBackendMockFactory.new()) as cursor: cursor.close = mock_close mock_close.assert_called_once_with() @@ -393,7 +389,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.schemas(**req_args) @@ -416,7 +412,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.tables(**req_args) @@ -439,7 +435,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = Mock() + mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.columns(**req_args) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 3679df85f..56ec8b6d2 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -7,19 +7,6 @@ except ImportError: pa = None -def noop_log_latency_decorator(*args, **kwargs): - """ - This is a no-op decorator. It is used to patch the log_latency decorator - during tests, so that the tests for fetches logic are not affected by the - telemetry logging logic. It accepts any arguments and returns a decorator - that returns the original function unmodified. - """ - def decorator(func): - return func - return decorator - -patch('databricks.sql.telemetry.latency_logger.log_latency', new=noop_log_latency_decorator).start() - import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue @@ -71,6 +58,14 @@ def make_dummy_result_set_from_initial_results(initial_results): for col_id in range(num_cols) ] return rs + + @staticmethod + def mock_thrift_backend_with_retry_policy(): + mock_thrift_backend = Mock() + mock_retry_policy = Mock() + mock_retry_policy.history = [] + mock_thrift_backend.retry_policy = mock_retry_policy + return mock_thrift_backend @staticmethod def make_dummy_result_set_from_batch_list(batch_list): @@ -92,7 +87,7 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = FetchTests.mock_thrift_backend_with_retry_policy() mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 From e031663ce4518030f81fa3a31610e32e562a2e35 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Tue, 1 Jul 2025 14:36:06 +0530 Subject: [PATCH 81/86] removed function in test_fetches cuz it is only used once Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/latency_logger.py | 2 +- tests/unit/test_fetches.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 06e6e598d..8c7272de4 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -145,7 +145,7 @@ def get_extractor(obj): TelemetryExtractor: A specialized extractor instance: - CursorExtractor for Cursor objects - ResultSetExtractor for ResultSet objects - - TelemetryExtractor (base) for all other objects + - Throws an NotImplementedError for all other objects """ if obj.__class__.__name__ == "Cursor": return CursorExtractor(obj) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 56ec8b6d2..b0a47ed15 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -59,14 +59,6 @@ def make_dummy_result_set_from_initial_results(initial_results): ] return rs - @staticmethod - def mock_thrift_backend_with_retry_policy(): - mock_thrift_backend = Mock() - mock_retry_policy = Mock() - mock_retry_policy.history = [] - mock_thrift_backend.retry_policy = mock_retry_policy - return mock_thrift_backend - @staticmethod def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 @@ -87,7 +79,10 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = FetchTests.mock_thrift_backend_with_retry_policy() + mock_thrift_backend = Mock() + mock_retry_policy = Mock() + mock_retry_policy.history = [] + mock_thrift_backend.retry_policy = mock_retry_policy mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 From 142b9a81136149c86e30663b78a9d0fba7a5f9ae Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 15:38:50 +0530 Subject: [PATCH 82/86] added _safe_call which returns None in case of errors in the get functions Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/latency_logger.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 8c7272de4..6180a0af4 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -195,18 +195,26 @@ def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) return result finally: + + def _safe_call(func_to_call): + """Calls a function and returns a default value on any exception.""" + try: + return func_to_call() + except Exception: + return None + end_time = time.perf_counter() duration_ms = int((end_time - start_time) * 1000) extractor = get_extractor(self) - session_id_hex = extractor.get_session_id_hex() - statement_id = extractor.get_statement_id() + session_id_hex = _safe_call(extractor.get_session_id_hex) + statement_id = _safe_call(extractor.get_statement_id) sql_exec_event = SqlExecutionEvent( statement_type=statement_type, - is_compressed=extractor.get_is_compressed(), - execution_result=extractor.get_execution_result(), - retry_count=extractor.get_retry_count(), + is_compressed=_safe_call(extractor.get_is_compressed), + execution_result=_safe_call(extractor.get_execution_result), + retry_count=_safe_call(extractor.get_retry_count), ) telemetry_client = TelemetryClientFactory.get_telemetry_client( From 2a26965d71b48b8944b29775549c0a70c3fc3e98 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 15:49:14 +0530 Subject: [PATCH 83/86] removed the changes in test_client and test_fetches Signed-off-by: Sai Shree Pradhan --- tests/unit/test_client.py | 21 ++++++--------------- tests/unit/test_fetches.py | 3 --- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c9e6a1d90..91e426c64 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -38,10 +38,6 @@ def new(cls): cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp()) - mock_retry_policy = Mock() - mock_retry_policy.history = [] - cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy) - cls.apply_property_to_mock( MockTExecuteStatementResp, description=None, @@ -322,7 +318,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_sets[1].fetchall.assert_called_once_with() def test_closed_cursor_doesnt_allow_operations(self): - cursor = client.Cursor(Mock(), ThriftBackendMockFactory.new()) + cursor = client.Cursor(Mock(), Mock()) cursor.close() with self.assertRaises(Error) as e: @@ -334,19 +330,14 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_connection = Mock() - mock_connection.get_session_id_hex.return_value = "test_session" - mock_execute_response = Mock() - mock_execute_response.command_handle = None - - result_set = client.ResultSet(mock_connection, mock_execute_response, ThriftBackendMockFactory.new()) + result_set = client.ResultSet(Mock(), Mock(), Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) def test_context_manager_closes_cursor(self): mock_close = Mock() - with client.Cursor(Mock(), ThriftBackendMockFactory.new()) as cursor: + with client.Cursor(Mock(), Mock()) as cursor: cursor.close = mock_close mock_close.assert_called_once_with() @@ -389,7 +380,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.schemas(**req_args) @@ -412,7 +403,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.tables(**req_args) @@ -435,7 +426,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} with self.subTest(req_args=req_args): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.columns(**req_args) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index b0a47ed15..357291179 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -80,9 +80,6 @@ def fetch_results( return results, batch_index < len(batch_list) mock_thrift_backend = Mock() - mock_retry_policy = Mock() - mock_retry_policy.history = [] - mock_thrift_backend.retry_policy = mock_retry_policy mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 From ae90deed6ecb88ca32b11c0c279e057784354b21 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 15:50:12 +0530 Subject: [PATCH 84/86] removed the changes in test_fetches Signed-off-by: Sai Shree Pradhan --- tests/unit/test_fetches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 357291179..71766f2cb 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -1,6 +1,6 @@ import unittest import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock try: import pyarrow as pa @@ -58,7 +58,7 @@ def make_dummy_result_set_from_initial_results(initial_results): for col_id in range(num_cols) ] return rs - + @staticmethod def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 From acc99048afc014c26c35d52e39df6a34fb8d3016 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 3 Jul 2025 10:48:13 +0530 Subject: [PATCH 85/86] test_telemetry Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry.py | 630 +++++++++++------------------------ 1 file changed, 188 insertions(+), 442 deletions(-) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 36461e9e9..71c343212 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -1,10 +1,7 @@ import uuid import pytest import requests -from unittest.mock import patch, MagicMock, call -import threading -import time -import random +from unittest.mock import patch, MagicMock from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -13,15 +10,7 @@ TelemetryHelper, BaseTelemetryClient ) -from databricks.sql.telemetry.models.enums import ( - AuthMech, - DatabricksClientType, - AuthFlow, -) -from databricks.sql.telemetry.models.event import ( - DriverConnectionParameters, - HostDetails, -) +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, DatabricksOAuthProvider, @@ -29,48 +18,15 @@ ) -@pytest.fixture -def noop_telemetry_client(): - """Fixture for NoopTelemetryClient.""" - return NoopTelemetryClient() - - -@pytest.fixture -def telemetry_client_setup(): - """Fixture for TelemetryClient setup data.""" - session_id_hex = str(uuid.uuid4()) - auth_provider = AccessTokenAuthProvider("test-token") - host_url = "test-host" - executor = MagicMock() - - client = TelemetryClient( - telemetry_enabled=True, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=executor, - ) - - return { - "client": client, - "session_id_hex": session_id_hex, - "auth_provider": auth_provider, - "host_url": host_url, - "executor": executor, - } - - @pytest.fixture def telemetry_system_reset(): - """Fixture to reset telemetry system state before each test.""" - # Reset the static state before each test + """Reset telemetry system state before each test.""" TelemetryClientFactory._clients.clear() if TelemetryClientFactory._executor: TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False yield - # Cleanup after test if needed TelemetryClientFactory._clients.clear() if TelemetryClientFactory._executor: TelemetryClientFactory._executor.shutdown(wait=True) @@ -78,359 +34,149 @@ def telemetry_system_reset(): TelemetryClientFactory._initialized = False +@pytest.fixture +def mock_telemetry_client(): + """Create a mock telemetry client for testing.""" + session_id = str(uuid.uuid4()) + auth_provider = AccessTokenAuthProvider("test-token") + executor = MagicMock() + + return TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=auth_provider, + host_url="test-host.com", + executor=executor, + ) + + class TestNoopTelemetryClient: - """Tests for the NoopTelemetryClient.""" + """Tests for NoopTelemetryClient - should do nothing safely.""" - def test_singleton(self): - """Test that NoopTelemetryClient is a singleton.""" + def test_noop_client_behavior(self): + """Test that NoopTelemetryClient is a singleton and all methods are safe no-ops.""" + # Test singleton behavior client1 = NoopTelemetryClient() client2 = NoopTelemetryClient() assert client1 is client2 - - def test_export_initial_telemetry_log(self, noop_telemetry_client): - """Test that export_initial_telemetry_log does nothing.""" - noop_telemetry_client.export_initial_telemetry_log( - driver_connection_params=MagicMock(), user_agent="test" - ) - - def test_export_failure_log(self, noop_telemetry_client): - """Test that export_failure_log does nothing.""" - noop_telemetry_client.export_failure_log( - error_name="TestError", error_message="Test error message" - ) - - def test_export_latency_log(self, noop_telemetry_client): - """Test that export_latency_log does nothing.""" - noop_telemetry_client.export_latency_log( - latency_ms=100, sql_execution_event="EXECUTE_STATEMENT", sql_statement_id="test-id" - ) - - def test_close(self, noop_telemetry_client): - """Test that close does nothing.""" - noop_telemetry_client.close() - - -class TestTelemetryClient: - """Tests for the TelemetryClient class.""" - - @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") - @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") - @patch("databricks.sql.telemetry.telemetry_client.time.time") - def test_export_initial_telemetry_log( - self, - mock_time, - mock_uuid4, - mock_get_driver_config, - mock_frontend_log, - telemetry_client_setup - ): - """Test exporting initial telemetry log.""" - mock_time.return_value = 1000 - mock_uuid4.return_value = "test-uuid" - mock_get_driver_config.return_value = "test-driver-config" - mock_frontend_log.return_value = MagicMock() - - client = telemetry_client_setup["client"] - host_url = telemetry_client_setup["host_url"] - client._export_event = MagicMock() - - driver_connection_params = DriverConnectionParameters( - http_path="test-path", - mode=DatabricksClientType.THRIFT, - host_info=HostDetails(host_url=host_url, port=443), - auth_mech=AuthMech.PAT, - auth_flow=None, - ) - user_agent = "test-user-agent" - - client.export_initial_telemetry_log(driver_connection_params, user_agent) - - mock_frontend_log.assert_called_once() - client._export_event.assert_called_once_with(mock_frontend_log.return_value) - - @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") - @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo") - @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") - @patch("databricks.sql.telemetry.telemetry_client.time.time") - def test_export_failure_log( - self, - mock_time, - mock_uuid4, - mock_driver_error_info, - mock_get_driver_config, - mock_frontend_log, - telemetry_client_setup - ): - """Test exporting failure telemetry log.""" - mock_time.return_value = 2000 - mock_uuid4.return_value = "test-error-uuid" - mock_get_driver_config.return_value = "test-driver-config" - mock_driver_error_info.return_value = MagicMock() - mock_frontend_log.return_value = MagicMock() - - client = telemetry_client_setup["client"] - client._export_event = MagicMock() - - client._driver_connection_params = "test-connection-params" - client._user_agent = "test-user-agent" - - error_name = "TestError" - error_message = "This is a test error message" - - client.export_failure_log(error_name, error_message) - - mock_driver_error_info.assert_called_once_with( - error_name=error_name, - stack_trace=error_message - ) - - mock_frontend_log.assert_called_once() - client._export_event.assert_called_once_with(mock_frontend_log.return_value) + # Test that all methods can be called without exceptions + client1.export_initial_telemetry_log(MagicMock(), "test-agent") + client1.export_failure_log("TestError", "Test message") + client1.export_latency_log(100, "EXECUTE_STATEMENT", "test-id") + client1.close() - @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog") - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration") - @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4") - @patch("databricks.sql.telemetry.telemetry_client.time.time") - def test_export_latency_log( - self, - mock_time, - mock_uuid4, - mock_get_driver_config, - mock_frontend_log, - telemetry_client_setup - ): - """Test exporting latency telemetry log.""" - mock_time.return_value = 3000 - mock_uuid4.return_value = "test-latency-uuid" - mock_get_driver_config.return_value = "test-driver-config" - mock_frontend_log.return_value = MagicMock() - - client = telemetry_client_setup["client"] - client._export_event = MagicMock() - - client._driver_connection_params = "test-connection-params" - client._user_agent = "test-user-agent" - - latency_ms = 150 - sql_execution_event = "test-execution-event" - sql_statement_id = "test-statement-id" - - client.export_latency_log(latency_ms, sql_execution_event, sql_statement_id) - - mock_frontend_log.assert_called_once() - - client._export_event.assert_called_once_with(mock_frontend_log.return_value) - - def test_export_event(self, telemetry_client_setup): - """Test exporting an event.""" - client = telemetry_client_setup["client"] - client._flush = MagicMock() - - for i in range(5): - client._export_event(f"event-{i}") - - client._flush.assert_not_called() - assert len(client._events_batch) == 5 - - for i in range(5, 10): - client._export_event(f"event-{i}") - - client._flush.assert_called_once() - assert len(client._events_batch) == 10 - - @patch("requests.post") - def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): - """Test sending telemetry to the server with authentication.""" - client = telemetry_client_setup["client"] - executor = telemetry_client_setup["executor"] - - events = [MagicMock(), MagicMock()] - events[0].to_json.return_value = '{"event": "1"}' - events[1].to_json.return_value = '{"event": "2"}' - - client._send_telemetry(events) - - executor.submit.assert_called_once() - args, kwargs = executor.submit.call_args - assert args[0] == requests.post - assert kwargs["timeout"] == 10 - assert "Authorization" in kwargs["headers"] - assert kwargs["headers"]["Authorization"] == "Bearer test-token" - @patch("requests.post") - def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup): - """Test sending telemetry to the server without authentication.""" - host_url = telemetry_client_setup["host_url"] - executor = telemetry_client_setup["executor"] +class TestTelemetryClient: + """Tests for actual telemetry client functionality and flows.""" + + def test_event_batching_and_flushing_flow(self, mock_telemetry_client): + """Test the complete event batching and flushing flow.""" + client = mock_telemetry_client + client._batch_size = 3 # Small batch for testing + + # Mock the network call + with patch.object(client, '_send_telemetry') as mock_send: + # Add events one by one - should not flush yet + client._export_event("event1") + client._export_event("event2") + mock_send.assert_not_called() + assert len(client._events_batch) == 2 + + # Third event should trigger flush + client._export_event("event3") + mock_send.assert_called_once() + assert len(client._events_batch) == 0 # Batch cleared after flush + + @patch('requests.post') + def test_network_request_flow(self, mock_post, mock_telemetry_client): + """Test the complete network request flow with authentication.""" + mock_post.return_value.status_code = 200 + client = mock_telemetry_client - unauthenticated_client = TelemetryClient( - telemetry_enabled=True, - session_id_hex=str(uuid.uuid4()), - auth_provider=None, # No auth provider - host_url=host_url, - executor=executor, - ) + # Create mock events + mock_events = [MagicMock() for _ in range(2)] + for i, event in enumerate(mock_events): + event.to_json.return_value = f'{{"event": "{i}"}}' - events = [MagicMock(), MagicMock()] - events[0].to_json.return_value = '{"event": "1"}' - events[1].to_json.return_value = '{"event": "2"}' + # Send telemetry + client._send_telemetry(mock_events) - unauthenticated_client._send_telemetry(events) + # Verify request was submitted to executor + client._executor.submit.assert_called_once() + args, kwargs = client._executor.submit.call_args - executor.submit.assert_called_once() - args, kwargs = executor.submit.call_args + # Verify correct function and URL assert args[0] == requests.post - assert kwargs["timeout"] == 10 - assert "Authorization" not in kwargs["headers"] # No auth header - assert kwargs["headers"]["Accept"] == "application/json" - assert kwargs["headers"]["Content-Type"] == "application/json" - - def test_flush(self, telemetry_client_setup): - """Test flushing events.""" - client = telemetry_client_setup["client"] - client._events_batch = ["event1", "event2"] - client._send_telemetry = MagicMock() - - client._flush() - - client._send_telemetry.assert_called_once_with(["event1", "event2"]) - assert client._events_batch == [] - - def test_close(self, telemetry_client_setup): - """Test closing the client.""" - client = telemetry_client_setup["client"] - client._flush = MagicMock() + assert args[1] == 'https://test-host.com/telemetry-ext' + assert kwargs['headers']['Authorization'] == 'Bearer test-token' + assert kwargs['timeout'] == 10 - client.close() - - client._flush.assert_called_once() - - @patch("requests.post") - def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup): - """Test successful telemetry request callback.""" - client = telemetry_client_setup["client"] - - mock_response = MagicMock() - mock_response.status_code = 200 - - mock_future = MagicMock() - mock_future.result.return_value = mock_response - - client._telemetry_request_callback(mock_future) - - mock_future.result.assert_called_once() - - @patch("requests.post") - def test_telemetry_request_callback_failure(self, mock_post, telemetry_client_setup): - """Test telemetry request callback with failure""" - client = telemetry_client_setup["client"] - - # Test with non-200 status code - mock_response = MagicMock() - mock_response.status_code = 500 - future = MagicMock() - future.result.return_value = mock_response - client._telemetry_request_callback(future) - - # Test with exception - future = MagicMock() - future.result.side_effect = Exception("Test error") - client._telemetry_request_callback(future) + # Verify request body structure + request_data = kwargs['data'] + assert '"uploadTime"' in request_data + assert '"protoLogs"' in request_data - def test_telemetry_client_exception_handling(self, telemetry_client_setup): - """Test exception handling in telemetry client methods.""" - client = telemetry_client_setup["client"] + def test_telemetry_logging_flows(self, mock_telemetry_client): + """Test all telemetry logging methods work end-to-end.""" + client = mock_telemetry_client - # Test export_initial_telemetry_log with exception - with patch.object(client, '_export_event', side_effect=Exception("Test error")): - # Should not raise exception + with patch.object(client, '_export_event') as mock_export: + # Test initial log client.export_initial_telemetry_log(MagicMock(), "test-agent") - - # Test export_failure_log with exception + assert mock_export.call_count == 1 + + # Test failure log + client.export_failure_log("TestError", "Error message") + assert mock_export.call_count == 2 + + # Test latency log + client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123") + assert mock_export.call_count == 3 + + def test_error_handling_resilience(self, mock_telemetry_client): + """Test that telemetry errors don't break the client.""" + client = mock_telemetry_client + + # Test that exceptions in telemetry don't propagate with patch.object(client, '_export_event', side_effect=Exception("Test error")): - # Should not raise exception - client.export_failure_log("TestError", "Test error message") - - # Test export_latency_log with exception - with patch.object(client, '_export_event', side_effect=Exception("Test error")): - # Should not raise exception - client.export_latency_log(100, "EXECUTE_STATEMENT", "test-statement-id") + # These should not raise exceptions + client.export_initial_telemetry_log(MagicMock(), "test-agent") + client.export_failure_log("TestError", "Error message") + client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123") - # Test _send_telemetry with exception - with patch.object(client._executor, 'submit', side_effect=Exception("Test error")): - # Should not raise exception - client._send_telemetry([MagicMock()]) - - def test_send_telemetry_thread_pool_failure(self, telemetry_client_setup): - """Test handling of thread pool submission failure""" - client = telemetry_client_setup["client"] - client._executor.submit.side_effect = Exception("Thread pool error") - event = MagicMock() - client._send_telemetry([event]) - - def test_base_telemetry_client_abstract_methods(self): - """Test that BaseTelemetryClient cannot be instantiated without implementing all abstract methods""" - class TestBaseClient(BaseTelemetryClient): - pass - - with pytest.raises(TypeError): - TestBaseClient() # Can't instantiate abstract class + # Test executor submission failure + client._executor.submit.side_effect = Exception("Thread pool error") + client._send_telemetry([MagicMock()]) # Should not raise class TestTelemetryHelper: - """Tests for the TelemetryHelper class.""" + """Tests for TelemetryHelper utility functions.""" - def test_get_driver_system_configuration(self): - """Test getting driver system configuration.""" - config = TelemetryHelper.get_driver_system_configuration() - - assert isinstance(config.driver_name, str) - assert isinstance(config.driver_version, str) - assert isinstance(config.runtime_name, str) - assert isinstance(config.runtime_vendor, str) - assert isinstance(config.runtime_version, str) - assert isinstance(config.os_name, str) - assert isinstance(config.os_version, str) - assert isinstance(config.os_arch, str) - assert isinstance(config.locale_name, str) - assert isinstance(config.char_set_encoding, str) - - assert config.driver_name == "Databricks SQL Python Connector" - assert "Python" in config.runtime_name - assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"] - assert config.os_name in ["Darwin", "Linux", "Windows"] - - # Verify caching behavior + def test_system_configuration_caching(self): + """Test that system configuration is cached and contains expected data.""" + config1 = TelemetryHelper.get_driver_system_configuration() config2 = TelemetryHelper.get_driver_system_configuration() - assert config is config2 # Should return same instance - - def test_get_auth_mechanism(self): - """Test getting auth mechanism for different auth providers.""" - # Test PAT auth - pat_auth = AccessTokenAuthProvider("test-token") - assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT - - # Test OAuth auth - oauth_auth = MagicMock(spec=DatabricksOAuthProvider) - assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH - # Test External auth - external_auth = MagicMock(spec=ExternalAuthProvider) - assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH - - # Test None auth provider - assert TelemetryHelper.get_auth_mechanism(None) is None - - # Test unknown auth provider - unknown_auth = MagicMock() - assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT - - def test_get_auth_flow(self): - """Test getting auth flow for different OAuth providers.""" - # Test OAuth with existing tokens + # Should be cached (same instance) + assert config1 is config2 + + def test_auth_mechanism_detection(self): + """Test authentication mechanism detection for different providers.""" + test_cases = [ + (AccessTokenAuthProvider("token"), AuthMech.PAT), + (MagicMock(spec=DatabricksOAuthProvider), AuthMech.DATABRICKS_OAUTH), + (MagicMock(spec=ExternalAuthProvider), AuthMech.EXTERNAL_AUTH), + (MagicMock(), AuthMech.CLIENT_CERT), # Unknown provider + (None, None), + ] + + for provider, expected in test_cases: + assert TelemetryHelper.get_auth_mechanism(provider) == expected + + def test_auth_flow_detection(self): + """Test authentication flow detection for OAuth providers.""" + # OAuth with existing tokens oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider) oauth_with_tokens._access_token = "test-access-token" oauth_with_tokens._refresh_token = "test-refresh-token" @@ -451,99 +197,99 @@ def test_get_auth_flow(self): assert TelemetryHelper.get_auth_flow(None) is None -class TestTelemetrySystem: - """Tests for the telemetry system functions.""" +class TestTelemetryFactory: + """Tests for TelemetryClientFactory lifecycle and management.""" - def test_initialize_telemetry_client_enabled(self, telemetry_system_reset): - """Test initializing a telemetry client when telemetry is enabled.""" - session_id_hex = "test-uuid" - auth_provider = MagicMock() - host_url = "test-host" + def test_client_lifecycle_flow(self, telemetry_system_reset): + """Test complete client lifecycle: initialize -> use -> close.""" + session_id_hex = "test-session" + auth_provider = AccessTokenAuthProvider("token") + # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, auth_provider=auth_provider, - host_url=host_url, + host_url="test-host.com" ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex - assert client._auth_provider == auth_provider - assert client._host_url == host_url - - def test_initialize_telemetry_client_disabled(self, telemetry_system_reset): - """Test initializing a telemetry client when telemetry is disabled.""" - session_id_hex = "test-uuid" - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=False, - session_id_hex=session_id_hex, - auth_provider=MagicMock(), - host_url="test-host", - ) + # Close client + with patch.object(client, 'close') as mock_close: + TelemetryClientFactory.close(session_id_hex) + mock_close.assert_called_once() + + # Should get NoopTelemetryClient after close client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_get_telemetry_client_nonexistent(self, telemetry_system_reset): - """Test getting a non-existent telemetry client.""" - client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid") - assert isinstance(client, NoopTelemetryClient) - - def test_close_telemetry_client(self, telemetry_system_reset): - """Test closing a telemetry client.""" - session_id_hex = "test-uuid" - auth_provider = MagicMock() - host_url = "test-host" + def test_disabled_telemetry_flow(self, telemetry_system_reset): + """Test that disabled telemetry uses NoopTelemetryClient.""" + session_id_hex = "test-session" TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, + telemetry_enabled=False, session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, + auth_provider=None, + host_url="test-host.com" ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, TelemetryClient) - - client.close = MagicMock() - - TelemetryClientFactory.close(session_id_hex) - - client.close.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_close_telemetry_client_noop(self, telemetry_system_reset): - """Test closing a no-op telemetry client.""" - session_id_hex = "test-uuid" - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=False, - session_id_hex=session_id_hex, - auth_provider=MagicMock(), - host_url="test-host", - ) - - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + def test_factory_error_handling(self, telemetry_system_reset): + """Test that factory errors fall back to NoopTelemetryClient.""" + session_id = "test-session" + + # Simulate initialization error + with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', + side_effect=Exception("Init error")): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=AccessTokenAuthProvider("token"), + host_url="test-host.com" + ) + + # Should fall back to NoopTelemetryClient + client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) + + def test_factory_shutdown_flow(self, telemetry_system_reset): + """Test factory shutdown when last client is removed.""" + session1 = "session-1" + session2 = "session-2" - client.close = MagicMock() + # Initialize multiple clients + for session in [session1, session2]: + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session, + auth_provider=AccessTokenAuthProvider("token"), + host_url="test-host.com" + ) - TelemetryClientFactory.close(session_id_hex) + # Factory should be initialized + assert TelemetryClientFactory._initialized is True + assert TelemetryClientFactory._executor is not None - client.close.assert_called_once() + # Close first client - factory should stay initialized + TelemetryClientFactory.close(session1) + assert TelemetryClientFactory._initialized is True - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) + # Close second client - factory should shut down + TelemetryClientFactory.close(session2) + assert TelemetryClientFactory._initialized is False + assert TelemetryClientFactory._executor is None - @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory._handle_unhandled_exception") - def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset): - """Test that global exception hook is installed and handles exceptions.""" - TelemetryClientFactory._install_exception_hook() - - test_exception = ValueError("Test exception") - TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None) - - mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None) \ No newline at end of file + +class TestAbstractInterface: + """Test that abstract base class works correctly.""" + + def test_base_client_cannot_be_instantiated(self): + """Test that BaseTelemetryClient cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseTelemetryClient() \ No newline at end of file From a84712267182019b3af6de08144febe771ac7de9 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 3 Jul 2025 12:24:52 +0530 Subject: [PATCH 86/86] removed test Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 71c343212..271e84970 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -283,13 +283,4 @@ def test_factory_shutdown_flow(self, telemetry_system_reset): # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None - - -class TestAbstractInterface: - """Test that abstract base class works correctly.""" - - def test_base_client_cannot_be_instantiated(self): - """Test that BaseTelemetryClient cannot be instantiated directly.""" - with pytest.raises(TypeError): - BaseTelemetryClient() \ No newline at end of file + assert TelemetryClientFactory._executor is None \ No newline at end of file