diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 936a0768..85dc4f66 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -402,7 +402,7 @@ def get_telemetry_client(session_id_hex): if session_id_hex in TelemetryClientFactory._clients: return TelemetryClientFactory._clients[session_id_hex] else: - logger.error( + logger.debug( "Telemetry client not initialized for connection %s", session_id_hex, ) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py new file mode 100644 index 00000000..bb148f23 --- /dev/null +++ b/tests/e2e/test_concurrent_telemetry.py @@ -0,0 +1,174 @@ +import threading +from unittest.mock import patch, MagicMock + +from databricks.sql.client import Connection +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory, TelemetryClient +from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.utils import ExecuteResponse +from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle, TOperationHandle, TOperationState, THandleIdentifier + +try: + import pyarrow as pa +except ImportError: + pa = None + + +def run_in_threads(target, num_threads, pass_index=False): + """Helper to run target function in multiple threads.""" + threads = [ + threading.Thread(target=target, args=(i,) if pass_index else ()) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + +class MockArrowQueue: + """Mock queue that behaves like ArrowQueue but returns empty results.""" + + def __init__(self): + # Create an empty arrow table if pyarrow is available, otherwise use None + if pa is not None: + self.empty_table = pa.table({'column': pa.array([])}) + else: + # Create a simple mock table-like object + self.empty_table = MagicMock() + self.empty_table.num_rows = 0 + self.empty_table.num_columns = 0 + + def next_n_rows(self, num_rows: int): + """Return empty results.""" + return self.empty_table + + def remaining_rows(self): + """Return empty results.""" + return self.empty_table + + +def test_concurrent_queries_with_telemetry_capture(): + """ + Test showing concurrent threads executing queries with real telemetry capture. + Uses the actual Connection and Cursor classes, mocking only the ThriftBackend. 
+ """ + num_threads = 5 + captured_telemetry = [] + connections = [] # Store connections to close them later + connections_lock = threading.Lock() # Thread safety for connections list + + def mock_send_telemetry(self, events): + """Capture telemetry events instead of sending them over network.""" + captured_telemetry.extend(events) + + # Clean up any existing state + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry): + # Mock the ThriftBackend to avoid actual network calls + with patch.object(ThriftBackend, 'open_session') as mock_open_session, \ + patch.object(ThriftBackend, 'execute_command') as mock_execute_command, \ + patch.object(ThriftBackend, 'close_session') as mock_close_session, \ + patch.object(ThriftBackend, 'fetch_results') as mock_fetch_results, \ + patch.object(ThriftBackend, 'close_command') as mock_close_command, \ + patch.object(ThriftBackend, 'handle_to_hex_id') as mock_handle_to_hex_id, \ + patch('databricks.sql.auth.thrift_http_client.THttpClient.open') as mock_transport_open: + + # Mock transport.open() to prevent actual network connection + mock_transport_open.return_value = None + + # Set up mock responses with proper structure + mock_handle_identifier = THandleIdentifier() + mock_handle_identifier.guid = b'1234567890abcdef' + mock_handle_identifier.secret = b'test_secret_1234' + + mock_session_handle = TSessionHandle() + mock_session_handle.sessionId = mock_handle_identifier + mock_session_handle.serverProtocolVersion = 1 + + mock_open_session.return_value = MagicMock( + sessionHandle=mock_session_handle, + serverProtocolVersion=1 + ) + + mock_handle_to_hex_id.return_value = "test-session-id-12345678" + + mock_op_handle = TOperationHandle() + mock_op_handle.operationId = THandleIdentifier() + mock_op_handle.operationId.guid = b'abcdef1234567890' + mock_op_handle.operationId.secret = b'op_secret_abcd' + + # Create proper mock arrow_queue with required methods + mock_arrow_queue = MockArrowQueue() + + mock_execute_response = ExecuteResponse( + arrow_queue=mock_arrow_queue, + description=[], + command_handle=mock_op_handle, + status=TOperationState.FINISHED_STATE, + has_been_closed_server_side=False, + has_more_rows=False, + lz4_compressed=False, + arrow_schema_bytes=b'', + is_staging_operation=False + ) + mock_execute_command.return_value = mock_execute_response + + # Mock fetch_results to return empty results + mock_fetch_results.return_value = (mock_arrow_queue, False) + + # Mock close_command to do nothing + mock_close_command.return_value = None + + # Mock close_session to do nothing + mock_close_session.return_value = None + + def execute_query_worker(thread_id): + """Each thread creates a connection and executes a query.""" + + # Create real Connection and Cursor objects + conn = Connection( + server_hostname="test-host", + http_path="/test/path", + access_token="test-token", + enable_telemetry=True + ) + + # Thread-safe storage of connection + with connections_lock: + connections.append(conn) + + cursor = conn.cursor() + # This will trigger the @log_latency decorator naturally + cursor.execute(f"SELECT {thread_id} as thread_id") + result = cursor.fetchall() + conn.close() + + + run_in_threads(execute_query_worker, num_threads, pass_index=True) + + # We expect at least 2 events per thread (one for open_session and 
one for execute_command) + assert len(captured_telemetry) >= num_threads*2 + print(f"Captured telemetry: {captured_telemetry}") + + # Verify the decorator was used (check some telemetry events have latency measurement) + events_with_latency = [ + e for e in captured_telemetry + if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log') + and e.entry.sql_driver_log.operation_latency_ms is not None + ] + assert len(events_with_latency) >= num_threads + + # Verify we have events with statement IDs (indicating @log_latency decorator worked) + events_with_statements = [ + e for e in captured_telemetry + if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log') + and e.entry.sql_driver_log.sql_statement_id is not None + ] + assert len(events_with_statements) >= num_threads + + \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 271e8497..c485c555 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,13 +2,16 @@ import pytest import requests from unittest.mock import patch, MagicMock +import threading +import random +import time +from concurrent.futures import ThreadPoolExecutor from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset): # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None + + +# A helper function to run a target in multiple threads and wait for them. +def run_in_threads(target, num_threads, pass_index=False): + """Creates, starts, and joins a specified number of threads. + + Args: + target: The function to run in each thread + num_threads: Number of threads to create + pass_index: If True, passes the thread index (0, 1, 2, ...) as first argument + """ + threads = [ + threading.Thread(target=target, args=(i,) if pass_index else ()) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + +class TestTelemetryRaceConditions: + """Tests for race conditions in multithreaded scenarios.""" + + @pytest.fixture(autouse=True) + def clean_factory(self): + """A fixture to automatically reset the factory's state before each test.""" + # Clean up at the start of each test + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + yield + + # Clean up at the end of each test + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + def test_factory_concurrent_initialization_of_DIFFERENT_clients(self): + """ + Tests that multiple threads creating DIFFERENT clients concurrently + share a single ThreadPoolExecutor and all clients are created successfully. 
+ """ + num_threads = 20 + + def create_client(thread_id): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=f"session_{thread_id}", + auth_provider=None, + host_url="test-host", + ) + + run_in_threads(create_client, 20, pass_index=True) + + # ASSERT: The factory was properly initialized + assert TelemetryClientFactory._initialized is True + assert TelemetryClientFactory._executor is not None + assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor) + + # ASSERT: All clients were successfully created + assert len(TelemetryClientFactory._clients) == num_threads + + # ASSERT: All TelemetryClient instances share the same executor + telemetry_clients = [ + client for client in TelemetryClientFactory._clients.values() + if isinstance(client, TelemetryClient) + ] + assert len(telemetry_clients) == num_threads + + shared_executor = TelemetryClientFactory._executor + for client in telemetry_clients: + assert client._executor is shared_executor + + def test_factory_concurrent_initialization_of_SAME_client(self): + """ + Tests that multiple threads trying to initialize the SAME client + result in only one client instance being created. + """ + session_id = "shared-session" + num_threads = 20 + + def create_same_client(): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test-host", + ) + + run_in_threads(create_same_client, num_threads) + + # ASSERT: Only one client was created in the factory. + assert len(TelemetryClientFactory._clients) == 1 + client = TelemetryClientFactory.get_telemetry_client(session_id) + assert isinstance(client, TelemetryClient) + + def test_client_concurrent_event_export(self): + """ + Tests that no events are lost when multiple threads call _export_event + on the same client instance concurrently. + """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + # Mock _flush to prevent auto-flushing when batch size threshold is reached + original_flush = client._flush + client._flush = MagicMock() + + num_threads = 5 + events_per_thread = 10 + + def add_events(): + for i in range(events_per_thread): + client._export_event(f"event-{i}") + + run_in_threads(add_events, num_threads) + + # ASSERT: The batch contains all events from all threads, none were lost. + total_expected_events = num_threads * events_per_thread + assert len(client._events_batch) == total_expected_events + + # Restore original flush method for cleanup + client._flush = original_flush + + def test_client_concurrent_flush(self): + """ + Tests that if multiple threads trigger _flush at the same time, + the underlying send operation is only called once for the batch. + """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + client._send_telemetry = MagicMock() + + # Pre-fill the batch so there's something to flush + client._events_batch = ["event"] * 5 + + def call_flush(): + client._flush() + + run_in_threads(call_flush, 10) + + # ASSERT: The send operation was called exactly once. + # This proves the lock prevents multiple threads from sending the same batch. + client._send_telemetry.assert_called_once() + # ASSERT: The event batch is now empty. + assert len(client._events_batch) == 0 + + def test_factory_concurrent_create_and_close(self): + """ + Tests that concurrently creating and closing different clients + doesn't corrupt the factory state and correctly shuts down the executor. 
+ """ + num_ops = 50 + + def create_and_close_client(i): + session_id = f"session_{i}" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host" + ) + # Small sleep to increase chance of interleaving operations + time.sleep(random.uniform(0, 0.01)) + TelemetryClientFactory.close(session_id) + + run_in_threads(create_and_close_client, num_ops, pass_index=True) + + # ASSERT: After all operations, the factory should be empty and reset. + assert not TelemetryClientFactory._clients + assert TelemetryClientFactory._executor is None + assert not TelemetryClientFactory._initialized \ No newline at end of file