diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index c137306a..dbf4fa0a 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -61,7 +61,8 @@
     DriverConnectionParameters,
     HostDetails,
 )
-
+from databricks.sql.telemetry.latency_logger import log_latency
+from databricks.sql.telemetry.models.enums import StatementType
 
 logger = logging.getLogger(__name__)
 
@@ -745,6 +746,7 @@ def _handle_staging_operation(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
+    @log_latency(StatementType.SQL)
     def _handle_staging_put(
         self, presigned_url: str, local_file: str, headers: Optional[dict] = None
     ):
@@ -784,6 +786,7 @@ def _handle_staging_put(
                 + "but not yet applied on the server. It's possible this command may fail later."
             )
 
+    @log_latency(StatementType.SQL)
     def _handle_staging_get(
         self, local_file: str, presigned_url: str, headers: Optional[dict] = None
     ):
@@ -811,6 +814,7 @@ def _handle_staging_get(
         with open(local_file, "wb") as fp:
             fp.write(r.content)
 
+    @log_latency(StatementType.SQL)
     def _handle_staging_remove(
         self, presigned_url: str, headers: Optional[dict] = None
     ):
@@ -824,6 +828,7 @@ def _handle_staging_remove(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
+    @log_latency(StatementType.QUERY)
     def execute(
         self,
         operation: str,
@@ -914,6 +919,7 @@ def execute(
 
         return self
 
+    @log_latency(StatementType.QUERY)
     def execute_async(
         self,
         operation: str,
@@ -1039,6 +1045,7 @@ def executemany(self, operation, seq_of_parameters):
             self.execute(operation, parameters)
         return self
 
+    @log_latency(StatementType.METADATA)
     def catalogs(self) -> "Cursor":
         """
         Get all available catalogs.
@@ -1062,6 +1069,7 @@ def catalogs(self) -> "Cursor":
         )
         return self
 
+    @log_latency(StatementType.METADATA)
     def schemas(
         self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
     ) -> "Cursor":
@@ -1090,6 +1098,7 @@ def schemas(
         )
         return self
 
+    @log_latency(StatementType.METADATA)
     def tables(
         self,
         catalog_name: Optional[str] = None,
@@ -1125,6 +1134,7 @@ def tables(
         )
         return self
 
+    @log_latency(StatementType.METADATA)
     def columns(
         self,
         catalog_name: Optional[str] = None,
@@ -1379,6 +1389,7 @@ def _fill_results_buffer(self):
             self.results = results
             self.has_more_rows = has_more_rows
 
+    @log_latency()
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
@@ -1391,6 +1402,7 @@ def _convert_columnar_table(self, table):
 
         return result
 
+    @log_latency()
     def _convert_arrow_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
@@ -1433,6 +1445,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index
 
+    @log_latency()
     def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1475,6 +1488,7 @@ def merge_columnar(self, result1, result2):
         ]
         return ColumnTable(merged_result, result1.column_names)
 
+    @log_latency()
     def fetchmany_columnar(self, size: int):
         """
         Fetch the next set of rows of a query result, returning a Columnar Table.
@@ -1500,6 +1514,7 @@ def fetchmany_columnar(self, size: int):
 
         return results
 
+    @log_latency()
     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
@@ -1526,6 +1541,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
             return pyarrow.Table.from_pydict(data)
         return results
 
+    @log_latency()
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
         results = self.results.remaining_rows()
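The hunks above only add decorators: the public cursor and result-set API is unchanged, and latency telemetry is emitted as a side effect of calls a user already makes. A minimal usage sketch (illustrative only, not part of the diff), assuming a reachable workspace and that telemetry is enabled for the connection (how it is enabled is not shown in this diff):

    from databricks import sql

    # Placeholders below must be replaced with real workspace details.
    with sql.connect(
        server_hostname="<workspace-host>",
        http_path="<http-path>",
        access_token="<personal-access-token>",
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1")  # wrapped by @log_latency(StatementType.QUERY)
            cursor.fetchall()           # fetch paths are wrapped by @log_latency()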
diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py
new file mode 100644
index 00000000..6180a0af
--- /dev/null
+++ b/src/databricks/sql/telemetry/latency_logger.py
@@ -0,0 +1,231 @@
+import time
+import functools
+from typing import Optional
+from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+from databricks.sql.telemetry.models.event import (
+    SqlExecutionEvent,
+)
+from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
+from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
+from uuid import UUID
+
+
+class TelemetryExtractor:
+    """
+    Base class for extracting telemetry information from various object types.
+
+    This class serves as a proxy that delegates attribute access to the wrapped object
+    while providing a common interface for extracting telemetry-related data.
+    """
+
+    def __init__(self, obj):
+        """
+        Initialize the extractor with an object to wrap.
+
+        Args:
+            obj: The object to extract telemetry information from.
+        """
+        self._obj = obj
+
+    def __getattr__(self, name):
+        """
+        Delegate attribute access to the wrapped object.
+
+        Args:
+            name (str): The name of the attribute to access.
+
+        Returns:
+            The attribute value from the wrapped object.
+        """
+        return getattr(self._obj, name)
+
+    def get_session_id_hex(self):
+        pass
+
+    def get_statement_id(self):
+        pass
+
+    def get_is_compressed(self):
+        pass
+
+    def get_execution_result(self):
+        pass
+
+    def get_retry_count(self):
+        pass
+
+
+class CursorExtractor(TelemetryExtractor):
+    """
+    Telemetry extractor specialized for Cursor objects.
+
+    Extracts telemetry information from database cursor objects, including
+    statement IDs, session information, compression settings, and result formats.
+    """
+
+    def get_statement_id(self) -> Optional[str]:
+        return self.query_id
+
+    def get_session_id_hex(self) -> Optional[str]:
+        return self.connection.get_session_id_hex()
+
+    def get_is_compressed(self) -> bool:
+        return self.connection.lz4_compression
+
+    def get_execution_result(self) -> ExecutionResultFormat:
+        if self.active_result_set is None:
+            return ExecutionResultFormat.FORMAT_UNSPECIFIED
+
+        if isinstance(self.active_result_set.results, ColumnQueue):
+            return ExecutionResultFormat.COLUMNAR_INLINE
+        elif isinstance(self.active_result_set.results, CloudFetchQueue):
+            return ExecutionResultFormat.EXTERNAL_LINKS
+        elif isinstance(self.active_result_set.results, ArrowQueue):
+            return ExecutionResultFormat.INLINE_ARROW
+        return ExecutionResultFormat.FORMAT_UNSPECIFIED
+
+    def get_retry_count(self) -> int:
+        if (
+            hasattr(self.thrift_backend, "retry_policy")
+            and self.thrift_backend.retry_policy
+        ):
+            return len(self.thrift_backend.retry_policy.history)
+        return 0
+
+
+class ResultSetExtractor(TelemetryExtractor):
+    """
+    Telemetry extractor specialized for ResultSet objects.
+
+    Extracts telemetry information from database result set objects, including
+    operation IDs, session information, compression settings, and result formats.
+    """
+
+    def get_statement_id(self) -> Optional[str]:
+        if self.command_id:
+            return str(UUID(bytes=self.command_id.operationId.guid))
+        return None
+
+    def get_session_id_hex(self) -> Optional[str]:
+        return self.connection.get_session_id_hex()
+
+    def get_is_compressed(self) -> bool:
+        return self.lz4_compressed
+
+    def get_execution_result(self) -> ExecutionResultFormat:
+        if isinstance(self.results, ColumnQueue):
+            return ExecutionResultFormat.COLUMNAR_INLINE
+        elif isinstance(self.results, CloudFetchQueue):
+            return ExecutionResultFormat.EXTERNAL_LINKS
+        elif isinstance(self.results, ArrowQueue):
+            return ExecutionResultFormat.INLINE_ARROW
+        return ExecutionResultFormat.FORMAT_UNSPECIFIED
+
+    def get_retry_count(self) -> int:
+        if (
+            hasattr(self.thrift_backend, "retry_policy")
+            and self.thrift_backend.retry_policy
+        ):
+            return len(self.thrift_backend.retry_policy.history)
+        return 0
+
+
+def get_extractor(obj):
+    """
+    Factory function to create the appropriate telemetry extractor for an object.
+
+    Determines the object type and returns the corresponding specialized extractor
+    that can extract telemetry information from that object type.
+
+    Args:
+        obj: The object to create an extractor for. Can be a Cursor, ResultSet,
+            or any other object.
+
+    Returns:
+        TelemetryExtractor: A specialized extractor instance:
+            - CursorExtractor for Cursor objects
+            - ResultSetExtractor for ResultSet objects
+            - Raises NotImplementedError for all other objects
+    """
+    if obj.__class__.__name__ == "Cursor":
+        return CursorExtractor(obj)
+    elif obj.__class__.__name__ == "ResultSet":
+        return ResultSetExtractor(obj)
+    else:
+        raise NotImplementedError(f"No extractor found for {obj.__class__.__name__}")
+
+
+def log_latency(statement_type: StatementType = StatementType.NONE):
+    """
+    Decorator for logging execution latency and telemetry information.
+
+    This decorator measures the execution time of a method and sends telemetry
+    data about the operation, including latency, statement information, and
+    execution context.
+
+    The decorator automatically:
+    - Measures execution time using high-precision performance counters
+    - Extracts telemetry information from the method's object (self)
+    - Creates a SqlExecutionEvent with execution details
+    - Sends the telemetry data asynchronously via TelemetryClient
+
+    Args:
+        statement_type (StatementType): The type of SQL statement being executed.
+
+    Usage:
+        @log_latency(StatementType.SQL)
+        def execute(self, query):
+            # Method implementation
+            pass
+
+    Returns:
+        function: A decorator that wraps methods to add latency logging.
+
+    Note:
+        The wrapped method's object (self) must be compatible with the
+        telemetry extractor system (e.g., Cursor or ResultSet objects).
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + start_time = time.perf_counter() + result = None + try: + result = func(self, *args, **kwargs) + return result + finally: + + def _safe_call(func_to_call): + """Calls a function and returns a default value on any exception.""" + try: + return func_to_call() + except Exception: + return None + + end_time = time.perf_counter() + duration_ms = int((end_time - start_time) * 1000) + + extractor = get_extractor(self) + session_id_hex = _safe_call(extractor.get_session_id_hex) + statement_id = _safe_call(extractor.get_statement_id) + + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=_safe_call(extractor.get_is_compressed), + execution_result=_safe_call(extractor.get_execution_result), + retry_count=_safe_call(extractor.get_retry_count), + ) + + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=statement_id, + ) + + return wrapper + + return decorator diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 10aa04ef..936a0768 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,7 +4,7 @@ import requests import logging from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional +from typing import Dict, Optional, List from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -112,6 +112,10 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): def export_failure_log(self, error_name, error_message): raise NotImplementedError("Subclasses must implement export_failure_log") + @abstractmethod + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + raise NotImplementedError("Subclasses must implement export_latency_log") + @abstractmethod def close(self): raise NotImplementedError("Subclasses must implement close") @@ -136,6 +140,9 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent): def export_failure_log(self, error_name, error_message): pass + def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + pass + def close(self): pass @@ -241,14 +248,24 @@ def _telemetry_request_callback(self, future): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) - def export_initial_telemetry_log(self, driver_connection_params, user_agent): - logger.debug( - "Exporting initial telemetry log for connection %s", self._session_id_hex - ) + def _export_telemetry_log(self, **telemetry_event_kwargs): + """ + Common helper method for exporting telemetry logs. 
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 10aa04ef..936a0768 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -4,7 +4,7 @@
 import requests
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, Optional
+from typing import Dict, Optional, List
 from databricks.sql.telemetry.models.event import (
     TelemetryEvent,
     DriverSystemConfiguration,
@@ -112,6 +112,10 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent):
     def export_failure_log(self, error_name, error_message):
         raise NotImplementedError("Subclasses must implement export_failure_log")
 
+    @abstractmethod
+    def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
+        raise NotImplementedError("Subclasses must implement export_latency_log")
+
     @abstractmethod
     def close(self):
         raise NotImplementedError("Subclasses must implement close")
@@ -136,6 +140,9 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent):
     def export_failure_log(self, error_name, error_message):
         pass
 
+    def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
+        pass
+
     def close(self):
         pass
 
@@ -241,14 +248,24 @@ def _telemetry_request_callback(self, future):
         except Exception as e:
             logger.debug("Telemetry request failed with exception: %s", e)
 
-    def export_initial_telemetry_log(self, driver_connection_params, user_agent):
-        logger.debug(
-            "Exporting initial telemetry log for connection %s", self._session_id_hex
-        )
+    def _export_telemetry_log(self, **telemetry_event_kwargs):
+        """
+        Common helper method for exporting telemetry logs.
+
+        Args:
+            **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor
+        """
+        logger.debug("Exporting telemetry log for connection %s", self._session_id_hex)
 
         try:
-            self._driver_connection_params = driver_connection_params
-            self._user_agent = user_agent
+            # Set common fields for all telemetry events
+            event_kwargs = {
+                "session_id": self._session_id_hex,
+                "system_configuration": TelemetryHelper.get_driver_system_configuration(),
+                "driver_connection_params": self._driver_connection_params,
+            }
+            # Add any additional fields passed in
+            event_kwargs.update(telemetry_event_kwargs)
 
             telemetry_frontend_log = TelemetryFrontendLog(
                 frontend_log_event_id=str(uuid.uuid4()),
@@ -258,46 +275,29 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent):
                         user_agent=self._user_agent,
                     )
                 ),
-                entry=FrontendLogEntry(
-                    sql_driver_log=TelemetryEvent(
-                        session_id=self._session_id_hex,
-                        system_configuration=TelemetryHelper.get_driver_system_configuration(),
-                        driver_connection_params=self._driver_connection_params,
-                    )
-                ),
+                entry=FrontendLogEntry(sql_driver_log=TelemetryEvent(**event_kwargs)),
             )
 
             self._export_event(telemetry_frontend_log)
         except Exception as e:
-            logger.debug("Failed to export initial telemetry log: %s", e)
+            logger.debug("Failed to export telemetry log: %s", e)
+
+    def export_initial_telemetry_log(self, driver_connection_params, user_agent):
+        self._driver_connection_params = driver_connection_params
+        self._user_agent = user_agent
+        self._export_telemetry_log()
 
     def export_failure_log(self, error_name, error_message):
-        logger.debug("Exporting failure log for connection %s", self._session_id_hex)
-        try:
-            error_info = DriverErrorInfo(
-                error_name=error_name, stack_trace=error_message
-            )
-            telemetry_frontend_log = TelemetryFrontendLog(
-                frontend_log_event_id=str(uuid.uuid4()),
-                context=FrontendLogContext(
-                    client_context=TelemetryClientContext(
-                        timestamp_millis=int(time.time() * 1000),
-                        user_agent=self._user_agent,
-                    )
-                ),
-                entry=FrontendLogEntry(
-                    sql_driver_log=TelemetryEvent(
-                        session_id=self._session_id_hex,
-                        system_configuration=TelemetryHelper.get_driver_system_configuration(),
-                        driver_connection_params=self._driver_connection_params,
-                        error_info=error_info,
-                    )
-                ),
-            )
-            self._export_event(telemetry_frontend_log)
-        except Exception as e:
-            logger.debug("Failed to export failure log: %s", e)
+        error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message)
+        self._export_telemetry_log(error_info=error_info)
+
+    def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
+        self._export_telemetry_log(
+            sql_statement_id=sql_statement_id,
+            sql_operation=sql_execution_event,
+            operation_latency_ms=latency_ms,
+        )
 
     def close(self):
         """Flush remaining events before closing"""
@@ -431,6 +431,9 @@ def close(session_id_hex):
                 logger.debug(
                     "No more telemetry clients, shutting down thread pool executor"
                 )
-                TelemetryClientFactory._executor.shutdown(wait=True)
+                try:
+                    TelemetryClientFactory._executor.shutdown(wait=True)
+                except Exception as e:
+                    logger.debug("Failed to shutdown thread pool executor: %s", e)
                 TelemetryClientFactory._executor = None
                 TelemetryClientFactory._initialized = False
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 699480bb..271e8497 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -1,7 +1,7 @@
 import uuid
 import pytest
 import requests
-from unittest.mock import patch, MagicMock, call
+from unittest.mock import patch, MagicMock
 
 from databricks.sql.telemetry.telemetry_client import (
     TelemetryClient,
@@ -10,15 +10,7 @@
     TelemetryHelper,
     BaseTelemetryClient
 )
-from databricks.sql.telemetry.models.enums import (
-    AuthMech,
-    DatabricksClientType,
-    AuthFlow,
-)
-from databricks.sql.telemetry.models.event import (
-    DriverConnectionParameters,
-    HostDetails,
-)
+from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
 from databricks.sql.auth.authenticators import (
     AccessTokenAuthProvider,
     DatabricksOAuthProvider,
@@ -26,48 +18,15 @@
 )
 
 
-@pytest.fixture
-def noop_telemetry_client():
-    """Fixture for NoopTelemetryClient."""
-    return NoopTelemetryClient()
-
-
-@pytest.fixture
-def telemetry_client_setup():
-    """Fixture for TelemetryClient setup data."""
-    session_id_hex = str(uuid.uuid4())
-    auth_provider = AccessTokenAuthProvider("test-token")
-    host_url = "test-host"
-    executor = MagicMock()
-
-    client = TelemetryClient(
-        telemetry_enabled=True,
-        session_id_hex=session_id_hex,
-        auth_provider=auth_provider,
-        host_url=host_url,
-        executor=executor,
-    )
-
-    return {
-        "client": client,
-        "session_id_hex": session_id_hex,
-        "auth_provider": auth_provider,
-        "host_url": host_url,
-        "executor": executor,
-    }
-
-
 @pytest.fixture
 def telemetry_system_reset():
-    """Fixture to reset telemetry system state before each test."""
-    # Reset the static state before each test
+    """Reset telemetry system state before each test."""
     TelemetryClientFactory._clients.clear()
     if TelemetryClientFactory._executor:
         TelemetryClientFactory._executor.shutdown(wait=True)
         TelemetryClientFactory._executor = None
     TelemetryClientFactory._initialized = False
     yield
-    # Cleanup after test if needed
    TelemetryClientFactory._clients.clear()
     if TelemetryClientFactory._executor:
         TelemetryClientFactory._executor.shutdown(wait=True)
@@ -75,314 +34,149 @@
     TelemetryClientFactory._initialized = False
 
 
+@pytest.fixture
+def mock_telemetry_client():
+    """Create a mock telemetry client for testing."""
+    session_id = str(uuid.uuid4())
+    auth_provider = AccessTokenAuthProvider("test-token")
+    executor = MagicMock()
+
+    return TelemetryClient(
+        telemetry_enabled=True,
+        session_id_hex=session_id,
+        auth_provider=auth_provider,
+        host_url="test-host.com",
+        executor=executor,
+    )
+
+
 class TestNoopTelemetryClient:
-    """Tests for the NoopTelemetryClient."""
+    """Tests for NoopTelemetryClient - should do nothing safely."""
 
-    def test_singleton(self):
-        """Test that NoopTelemetryClient is a singleton."""
+    def test_noop_client_behavior(self):
+        """Test that NoopTelemetryClient is a singleton and all methods are safe no-ops."""
+        # Test singleton behavior
         client1 = NoopTelemetryClient()
         client2 = NoopTelemetryClient()
         assert client1 is client2
-
-    def test_export_initial_telemetry_log(self, noop_telemetry_client):
-        """Test that export_initial_telemetry_log does nothing."""
-        noop_telemetry_client.export_initial_telemetry_log(
-            driver_connection_params=MagicMock(), user_agent="test"
-        )
-
-    def test_export_failure_log(self, noop_telemetry_client):
-        """Test that export_failure_log does nothing."""
-        noop_telemetry_client.export_failure_log(
-            error_name="TestError", error_message="Test error message"
-        )
-
-    def test_close(self, noop_telemetry_client):
-        """Test that close does nothing."""
-        noop_telemetry_client.close()
-
-
-class TestTelemetryClient:
-    """Tests for the TelemetryClient class."""
-
-    @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog")
-    @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration")
-    @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4")
-    @patch("databricks.sql.telemetry.telemetry_client.time.time")
-    def test_export_initial_telemetry_log(
-        self,
-        mock_time,
-        mock_uuid4,
-        mock_get_driver_config,
-        mock_frontend_log,
-        telemetry_client_setup
-    ):
-        """Test exporting initial telemetry log."""
-        mock_time.return_value = 1000
-        mock_uuid4.return_value = "test-uuid"
-        mock_get_driver_config.return_value = "test-driver-config"
-        mock_frontend_log.return_value = MagicMock()
-
-        client = telemetry_client_setup["client"]
-        host_url = telemetry_client_setup["host_url"]
-        client._export_event = MagicMock()
-
-        driver_connection_params = DriverConnectionParameters(
-            http_path="test-path",
-            mode=DatabricksClientType.THRIFT,
-            host_info=HostDetails(host_url=host_url, port=443),
-            auth_mech=AuthMech.PAT,
-            auth_flow=None,
-        )
-        user_agent = "test-user-agent"
-
-        client.export_initial_telemetry_log(driver_connection_params, user_agent)
-        mock_frontend_log.assert_called_once()
-        client._export_event.assert_called_once_with(mock_frontend_log.return_value)
+        # Test that all methods can be called without exceptions
+        client1.export_initial_telemetry_log(MagicMock(), "test-agent")
+        client1.export_failure_log("TestError", "Test message")
+        client1.export_latency_log(100, "EXECUTE_STATEMENT", "test-id")
+        client1.close()
 
-    @patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog")
-    @patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration")
-    @patch("databricks.sql.telemetry.telemetry_client.DriverErrorInfo")
-    @patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4")
-    @patch("databricks.sql.telemetry.telemetry_client.time.time")
-    def test_export_failure_log(
-        self,
-        mock_time,
-        mock_uuid4,
-        mock_driver_error_info,
-        mock_get_driver_config,
-        mock_frontend_log,
-        telemetry_client_setup
-    ):
-        """Test exporting failure telemetry log."""
-        mock_time.return_value = 2000
-        mock_uuid4.return_value = "test-error-uuid"
-        mock_get_driver_config.return_value = "test-driver-config"
-        mock_driver_error_info.return_value = MagicMock()
-        mock_frontend_log.return_value = MagicMock()
-        client = telemetry_client_setup["client"]
-        client._export_event = MagicMock()
-
-        client._driver_connection_params = "test-connection-params"
-        client._user_agent = "test-user-agent"
-
-        error_name = "TestError"
-        error_message = "This is a test error message"
-
-        client.export_failure_log(error_name, error_message)
-
-        mock_driver_error_info.assert_called_once_with(
-            error_name=error_name,
-            stack_trace=error_message
-        )
-
-        mock_frontend_log.assert_called_once()
-
-        client._export_event.assert_called_once_with(mock_frontend_log.return_value)
-
-    def test_export_event(self, telemetry_client_setup):
-        """Test exporting an event."""
-        client = telemetry_client_setup["client"]
-        client._flush = MagicMock()
-
-        for i in range(5):
-            client._export_event(f"event-{i}")
-
-        client._flush.assert_not_called()
-        assert len(client._events_batch) == 5
-
-        for i in range(5, 10):
-            client._export_event(f"event-{i}")
+class TestTelemetryClient:
+    """Tests for actual telemetry client functionality and flows."""
+
+    def test_event_batching_and_flushing_flow(self, mock_telemetry_client):
+        """Test the complete event batching and flushing flow."""
+        client = mock_telemetry_client
+        client._batch_size = 3  # Small batch for testing
+
+        # Mock the network call
+        with patch.object(client, '_send_telemetry') as mock_send:
+            # Add events one by one - should not flush yet
+            client._export_event("event1")
+            client._export_event("event2")
+            mock_send.assert_not_called()
+            assert len(client._events_batch) == 2
+
+            # Third event should trigger flush
+            client._export_event("event3")
+            mock_send.assert_called_once()
+            assert len(client._events_batch) == 0  # Batch cleared after flush
+
+    @patch('requests.post')
+    def test_network_request_flow(self, mock_post, mock_telemetry_client):
+        """Test the complete network request flow with authentication."""
+        mock_post.return_value.status_code = 200
+        client = mock_telemetry_client
 
-        client._flush.assert_called_once()
-        assert len(client._events_batch) == 10
-
-    @patch("requests.post")
-    def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
-        """Test sending telemetry to the server with authentication."""
-        client = telemetry_client_setup["client"]
-        executor = telemetry_client_setup["executor"]
+        # Create mock events
+        mock_events = [MagicMock() for _ in range(2)]
+        for i, event in enumerate(mock_events):
+            event.to_json.return_value = f'{{"event": "{i}"}}'
 
-        events = [MagicMock(), MagicMock()]
-        events[0].to_json.return_value = '{"event": "1"}'
-        events[1].to_json.return_value = '{"event": "2"}'
+        # Send telemetry
+        client._send_telemetry(mock_events)
 
-        client._send_telemetry(events)
+        # Verify request was submitted to executor
+        client._executor.submit.assert_called_once()
+        args, kwargs = client._executor.submit.call_args
 
-        executor.submit.assert_called_once()
-        args, kwargs = executor.submit.call_args
+        # Verify correct function and URL
         assert args[0] == requests.post
-        assert kwargs["timeout"] == 10
-        assert "Authorization" in kwargs["headers"]
-        assert kwargs["headers"]["Authorization"] == "Bearer test-token"
-
-    @patch("requests.post")
-    def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
-        """Test sending telemetry to the server without authentication."""
-        host_url = telemetry_client_setup["host_url"]
-        executor = telemetry_client_setup["executor"]
-
-        unauthenticated_client = TelemetryClient(
-            telemetry_enabled=True,
-            session_id_hex=str(uuid.uuid4()),
-            auth_provider=None,  # No auth provider
-            host_url=host_url,
-            executor=executor,
-        )
+        assert args[1] == 'https://test-host.com/telemetry-ext'
+        assert kwargs['headers']['Authorization'] == 'Bearer test-token'
+        assert kwargs['timeout'] == 10
 
-        events = [MagicMock(), MagicMock()]
-        events[0].to_json.return_value = '{"event": "1"}'
-        events[1].to_json.return_value = '{"event": "2"}'
-
-        unauthenticated_client._send_telemetry(events)
-
-        executor.submit.assert_called_once()
-        args, kwargs = executor.submit.call_args
-        assert args[0] == requests.post
-        assert kwargs["timeout"] == 10
-        assert "Authorization" not in kwargs["headers"]  # No auth header
-        assert kwargs["headers"]["Accept"] == "application/json"
-        assert kwargs["headers"]["Content-Type"] == "application/json"
+        # Verify request body structure
+        request_data = kwargs['data']
+        assert '"uploadTime"' in request_data
+        assert '"protoLogs"' in request_data
 
-    def test_flush(self, telemetry_client_setup):
-        """Test flushing events."""
-        client = telemetry_client_setup["client"]
-        client._events_batch = ["event1", "event2"]
-        client._send_telemetry = MagicMock()
+    def test_telemetry_logging_flows(self, mock_telemetry_client):
+        """Test all telemetry logging methods work end-to-end."""
+        client = mock_telemetry_client
 
-        client._flush()
-        client._send_telemetry.assert_called_once_with(["event1", "event2"])
-        assert client._events_batch == []
-
-    def test_close(self, telemetry_client_setup):
-        """Test closing the client."""
-        client = telemetry_client_setup["client"]
-        client._flush = MagicMock()
-
-        client.close()
-
-        client._flush.assert_called_once()
-
-    @patch("requests.post")
-    def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup):
-        """Test successful telemetry request callback."""
-        client = telemetry_client_setup["client"]
-
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-
-        mock_future = MagicMock()
-        mock_future.result.return_value = mock_response
-
-        client._telemetry_request_callback(mock_future)
-
-        mock_future.result.assert_called_once()
-
-    @patch("requests.post")
-    def test_telemetry_request_callback_failure(self, mock_post, telemetry_client_setup):
-        """Test telemetry request callback with failure"""
-        client = telemetry_client_setup["client"]
-
-        # Test with non-200 status code
-        mock_response = MagicMock()
-        mock_response.status_code = 500
-        future = MagicMock()
-        future.result.return_value = mock_response
-        client._telemetry_request_callback(future)
-
-        # Test with exception
-        future = MagicMock()
-        future.result.side_effect = Exception("Test error")
-        client._telemetry_request_callback(future)
-
-    def test_telemetry_client_exception_handling(self, telemetry_client_setup):
-        """Test exception handling in telemetry client methods."""
-        client = telemetry_client_setup["client"]
-
-        # Test export_initial_telemetry_log with exception
-        with patch.object(client, '_export_event', side_effect=Exception("Test error")):
-            # Should not raise exception
+        with patch.object(client, '_export_event') as mock_export:
+            # Test initial log
             client.export_initial_telemetry_log(MagicMock(), "test-agent")
-
-        # Test export_failure_log with exception
+            assert mock_export.call_count == 1
+
+            # Test failure log
+            client.export_failure_log("TestError", "Error message")
+            assert mock_export.call_count == 2
+
+            # Test latency log
+            client.export_latency_log(150, "EXECUTE_STATEMENT", "stmt-123")
+            assert mock_export.call_count == 3
+
+    def test_error_handling_resilience(self, mock_telemetry_client):
+        """Test that telemetry errors don't break the client."""
+        client = mock_telemetry_client
+
+        # Test that exceptions in telemetry don't propagate
         with patch.object(client, '_export_event', side_effect=Exception("Test error")):
-            # Should not raise exception
-            client.export_failure_log("TestError", "Test error message")
+            # These should not raise exceptions
+            client.export_initial_telemetry_log(MagicMock(), "test-agent")
+            client.export_failure_log("TestError", "Error message")
+            client.export_latency_log(100, "EXECUTE_STATEMENT", "stmt-123")
 
-        # Test _send_telemetry with exception
-        with patch.object(client._executor, 'submit', side_effect=Exception("Test error")):
-            # Should not raise exception
-            client._send_telemetry([MagicMock()])
-
-    def test_send_telemetry_thread_pool_failure(self, telemetry_client_setup):
-        """Test handling of thread pool submission failure"""
-        client = telemetry_client_setup["client"]
-        client._executor.submit.side_effect = Exception("Thread pool error")
-        event = MagicMock()
-        client._send_telemetry([event])
-
-    def test_base_telemetry_client_abstract_methods(self):
-        """Test that BaseTelemetryClient cannot be instantiated without implementing all abstract methods"""
-        class TestBaseClient(BaseTelemetryClient):
-            pass
-
-        with pytest.raises(TypeError):
-            TestBaseClient()  # Can't instantiate abstract class
+        # Test executor submission failure
+        client._executor.submit.side_effect = Exception("Thread pool error")
+        client._send_telemetry([MagicMock()])  # Should not raise
 
 
 class TestTelemetryHelper:
-    """Tests for the TelemetryHelper class."""
+    """Tests for TelemetryHelper utility functions."""
 
-    def test_get_driver_system_configuration(self):
-        """Test getting driver system configuration."""
-        config = TelemetryHelper.get_driver_system_configuration()
-
-        assert isinstance(config.driver_name, str)
-        assert isinstance(config.driver_version, str)
-        assert isinstance(config.runtime_name, str)
-        assert isinstance(config.runtime_vendor, str)
-        assert isinstance(config.runtime_version, str)
-        assert isinstance(config.os_name, str)
-        assert isinstance(config.os_version, str)
-        assert isinstance(config.os_arch, str)
-        assert isinstance(config.locale_name, str)
-        assert isinstance(config.char_set_encoding, str)
-
-        assert config.driver_name == "Databricks SQL Python Connector"
-        assert "Python" in config.runtime_name
-        assert config.runtime_vendor in ["CPython", "PyPy", "Jython", "IronPython"]
-        assert config.os_name in ["Darwin", "Linux", "Windows"]
-
-        # Verify caching behavior
+    def test_system_configuration_caching(self):
+        """Test that system configuration is cached and contains expected data."""
+        config1 = TelemetryHelper.get_driver_system_configuration()
         config2 = TelemetryHelper.get_driver_system_configuration()
-        assert config is config2  # Should return same instance
-
-    def test_get_auth_mechanism(self):
-        """Test getting auth mechanism for different auth providers."""
-        # Test PAT auth
-        pat_auth = AccessTokenAuthProvider("test-token")
-        assert TelemetryHelper.get_auth_mechanism(pat_auth) == AuthMech.PAT
-
-        # Test OAuth auth
-        oauth_auth = MagicMock(spec=DatabricksOAuthProvider)
-        assert TelemetryHelper.get_auth_mechanism(oauth_auth) == AuthMech.DATABRICKS_OAUTH
-        # Test External auth
-        external_auth = MagicMock(spec=ExternalAuthProvider)
-        assert TelemetryHelper.get_auth_mechanism(external_auth) == AuthMech.EXTERNAL_AUTH
-
-        # Test None auth provider
-        assert TelemetryHelper.get_auth_mechanism(None) is None
-
-        # Test unknown auth provider
-        unknown_auth = MagicMock()
-        assert TelemetryHelper.get_auth_mechanism(unknown_auth) == AuthMech.CLIENT_CERT
-
-    def test_get_auth_flow(self):
-        """Test getting auth flow for different OAuth providers."""
-        # Test OAuth with existing tokens
+        # Should be cached (same instance)
+        assert config1 is config2
+
+    def test_auth_mechanism_detection(self):
+        """Test authentication mechanism detection for different providers."""
+        test_cases = [
+            (AccessTokenAuthProvider("token"), AuthMech.PAT),
+            (MagicMock(spec=DatabricksOAuthProvider), AuthMech.DATABRICKS_OAUTH),
+            (MagicMock(spec=ExternalAuthProvider), AuthMech.EXTERNAL_AUTH),
+            (MagicMock(), AuthMech.CLIENT_CERT),  # Unknown provider
+            (None, None),
+        ]
+
+        for provider, expected in test_cases:
+            assert TelemetryHelper.get_auth_mechanism(provider) == expected
+
+    def test_auth_flow_detection(self):
+        """Test authentication flow detection for OAuth providers."""
+        # OAuth with existing tokens
         oauth_with_tokens = MagicMock(spec=DatabricksOAuthProvider)
         oauth_with_tokens._access_token = "test-access-token"
         oauth_with_tokens._refresh_token = "test-refresh-token"
@@ -403,99 +197,90 @@ def test_get_auth_flow(self):
         assert TelemetryHelper.get_auth_flow(None) is None
 
 
-class TestTelemetrySystem:
-    """Tests for the telemetry system functions."""
+class TestTelemetryFactory:
+    """Tests for TelemetryClientFactory lifecycle and management."""
 
-    def test_initialize_telemetry_client_enabled(self, telemetry_system_reset):
-        """Test initializing a telemetry client when telemetry is enabled."""
-        session_id_hex = "test-uuid"
-        auth_provider = MagicMock()
-        host_url = "test-host"
+    def test_client_lifecycle_flow(self, telemetry_system_reset):
+        """Test complete client lifecycle: initialize -> use -> close."""
+        session_id_hex = "test-session"
+        auth_provider = AccessTokenAuthProvider("token")
+
+        # Initialize enabled client
         TelemetryClientFactory.initialize_telemetry_client(
             telemetry_enabled=True,
             session_id_hex=session_id_hex,
             auth_provider=auth_provider,
-            host_url=host_url,
+            host_url="test-host.com"
         )
 
         client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
         assert isinstance(client, TelemetryClient)
         assert client._session_id_hex == session_id_hex
-        assert client._auth_provider == auth_provider
-        assert client._host_url == host_url
-
-    def test_initialize_telemetry_client_disabled(self, telemetry_system_reset):
-        """Test initializing a telemetry client when telemetry is disabled."""
-        session_id_hex = "test-uuid"
-        TelemetryClientFactory.initialize_telemetry_client(
-            telemetry_enabled=False,
-            session_id_hex=session_id_hex,
-            auth_provider=MagicMock(),
-            host_url="test-host",
-        )
-
-        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
-        assert isinstance(client, NoopTelemetryClient)
-
-    def test_get_telemetry_client_nonexistent(self, telemetry_system_reset):
-        """Test getting a non-existent telemetry client."""
-        client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid")
-        assert isinstance(client, NoopTelemetryClient)
-
-    def test_close_telemetry_client(self, telemetry_system_reset):
-        """Test closing a telemetry client."""
-        session_id_hex = "test-uuid"
-        auth_provider = MagicMock()
-        host_url = "test-host"
-
-        TelemetryClientFactory.initialize_telemetry_client(
-            telemetry_enabled=True,
-            session_id_hex=session_id_hex,
-            auth_provider=auth_provider,
-            host_url=host_url,
-        )
-
-        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
-        assert isinstance(client, TelemetryClient)
-
-        client.close = MagicMock()
-
-        TelemetryClientFactory.close(session_id_hex)
-
-        client.close.assert_called_once()
+        # Close client
+        with patch.object(client, 'close') as mock_close:
+            TelemetryClientFactory.close(session_id_hex)
+            mock_close.assert_called_once()
 
+        # Should get NoopTelemetryClient after close
         client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
         assert isinstance(client, NoopTelemetryClient)
 
-    def test_close_telemetry_client_noop(self, telemetry_system_reset):
-        """Test closing a no-op telemetry client."""
-        session_id_hex = "test-uuid"
+    def test_disabled_telemetry_flow(self, telemetry_system_reset):
+        """Test that disabled telemetry uses NoopTelemetryClient."""
+        session_id_hex = "test-session"
+
        TelemetryClientFactory.initialize_telemetry_client(
            telemetry_enabled=False,
            session_id_hex=session_id_hex,
-            auth_provider=MagicMock(),
-            host_url="test-host",
+            auth_provider=None,
+            host_url="test-host.com"
        )

        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
        assert isinstance(client, NoopTelemetryClient)
-
-        client.close = MagicMock()
-
-        TelemetryClientFactory.close(session_id_hex)
-
-        client.close.assert_called_once()
-
-        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+
+    def test_factory_error_handling(self, telemetry_system_reset):
+        """Test that factory errors fall back to NoopTelemetryClient."""
+        session_id = "test-session"
+
+        # Simulate initialization error
+        with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient',
+                   side_effect=Exception("Init error")):
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session_id,
+                auth_provider=AccessTokenAuthProvider("token"),
+                host_url="test-host.com"
+            )
+
+        # Should fall back to NoopTelemetryClient
+        client = TelemetryClientFactory.get_telemetry_client(session_id)
         assert isinstance(client, NoopTelemetryClient)
 
-    @patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory._handle_unhandled_exception")
-    def test_global_exception_hook(self, mock_handle_exception, telemetry_system_reset):
-        """Test that global exception hook is installed and handles exceptions."""
-        TelemetryClientFactory._install_exception_hook()
-
-        test_exception = ValueError("Test exception")
-        TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None)
-
-        mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None)
\ No newline at end of file
+    def test_factory_shutdown_flow(self, telemetry_system_reset):
+        """Test factory shutdown when last client is removed."""
+        session1 = "session-1"
+        session2 = "session-2"
+
+        # Initialize multiple clients
+        for session in [session1, session2]:
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session,
+                auth_provider=AccessTokenAuthProvider("token"),
+                host_url="test-host.com"
+            )
+
+        # Factory should be initialized
+        assert TelemetryClientFactory._initialized is True
+        assert TelemetryClientFactory._executor is not None
+
+        # Close first client - factory should stay initialized
+        TelemetryClientFactory.close(session1)
+        assert TelemetryClientFactory._initialized is True
+
+        # Close second client - factory should shut down
+        TelemetryClientFactory.close(session2)
+        assert TelemetryClientFactory._initialized is False
+        assert TelemetryClientFactory._executor is None
\ No newline at end of file