From 67a8497861f1ed58cbc501f5af83b4a471da4284 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Wed, 25 Jun 2025 10:49:28 +0530
Subject: [PATCH 1/2] added multithreaded tests, exception handling tests

Signed-off-by: Sai Shree Pradhan
---
 .../sql/telemetry/telemetry_client.py |   8 +-
 tests/unit/test_telemetry.py          | 408 +++++++++++++++++-
 2 files changed, 407 insertions(+), 9 deletions(-)

diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 10aa04ef..db9299ab 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -149,6 +149,7 @@ class TelemetryClient(BaseTelemetryClient):
     # Telemetry endpoint paths
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
+    DEFAULT_BATCH_SIZE = 10
 
     def __init__(
         self,
@@ -160,7 +161,7 @@ def __init__(
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
-        self._batch_size = 10  # TODO: Decide on batch size
+        self._batch_size = self.DEFAULT_BATCH_SIZE  # TODO: Decide on batch size
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None
@@ -431,6 +432,9 @@ def close(session_id_hex):
                 logger.debug(
                     "No more telemetry clients, shutting down thread pool executor"
                 )
-                TelemetryClientFactory._executor.shutdown(wait=True)
+                try:
+                    TelemetryClientFactory._executor.shutdown(wait=True)
+                except Exception as e:
+                    logger.debug("Failed to shutdown thread pool executor: %s", e)
                 TelemetryClientFactory._executor = None
                 TelemetryClientFactory._initialized = False
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 699480bb..d1611909 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -1,7 +1,9 @@
 import uuid
 import pytest
 import requests
-from unittest.mock import patch, MagicMock, call
+from unittest.mock import patch, MagicMock
+import threading
+import random
 
 from databricks.sql.telemetry.telemetry_client import (
     TelemetryClient,
@@ -186,17 +188,16 @@ def test_export_event(self, telemetry_client_setup):
         client = telemetry_client_setup["client"]
         client._flush = MagicMock()
 
-        for i in range(5):
+        for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1):
             client._export_event(f"event-{i}")
 
         client._flush.assert_not_called()
-        assert len(client._events_batch) == 5
+        assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1
 
-        for i in range(5, 10):
-            client._export_event(f"event-{i}")
+        # Add one more event to reach batch size (this will trigger flush)
+        client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}")
 
         client._flush.assert_called_once()
-        assert len(client._events_batch) == 10
 
     @patch("requests.post")
     def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
@@ -498,4 +499,397 @@ def test_global_exception_hook(self, mock_handle_exception, telemetry_system_res
         test_exception = ValueError("Test exception")
         TelemetryClientFactory._handle_unhandled_exception(type(test_exception), test_exception, None)
 
-        mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None)
\ No newline at end of file
+        mock_handle_exception.assert_called_once_with(type(test_exception), test_exception, None)
+
+    def test_initialize_telemetry_client_exception_handling(self, telemetry_system_reset):
+        """Test that exceptions in initialize_telemetry_client don't cause connector to fail."""
+        session_id_hex = "test-uuid"
+        auth_provider = MagicMock()
+        host_url = "test-host"
+
+        # Test exception during TelemetryClient creation
+        with patch('databricks.sql.telemetry.telemetry_client.TelemetryClient', side_effect=Exception("TelemetryClient creation failed")):
+            # Should not raise exception, should fallback to NoopTelemetryClient
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session_id_hex,
+                auth_provider=auth_provider,
+                host_url=host_url,
+            )
+
+            client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+            assert isinstance(client, NoopTelemetryClient)
+
+    def test_get_telemetry_client_exception_handling(self, telemetry_system_reset):
+        """Test that exceptions in get_telemetry_client don't cause connector to fail."""
+        session_id_hex = "test-uuid"
+
+        # Test exception during client lookup by mocking the clients dict
+        mock_clients = MagicMock()
+        mock_clients.__contains__.side_effect = Exception("Client lookup failed")
+
+        with patch.object(TelemetryClientFactory, '_clients', mock_clients):
+            # Should not raise exception, should return NoopTelemetryClient
+            client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+            assert isinstance(client, NoopTelemetryClient)
+
+    def test_get_telemetry_client_dict_access_exception(self, telemetry_system_reset):
+        """Test that exceptions during dictionary access don't cause connector to fail."""
+        session_id_hex = "test-uuid"
+
+        # Test exception during dictionary access
+        mock_clients = MagicMock()
+        mock_clients.__contains__.side_effect = Exception("Dictionary access failed")
+        TelemetryClientFactory._clients = mock_clients
+
+        # Should not raise exception, should return NoopTelemetryClient
+        client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+        assert isinstance(client, NoopTelemetryClient)
+
+    def test_close_telemetry_client_shutdown_executor_exception(self, telemetry_system_reset):
+        """Test that exceptions during executor shutdown don't cause connector to fail."""
+        session_id_hex = "test-uuid"
+        auth_provider = MagicMock()
+        host_url = "test-host"
+
+        # Initialize a client first
+        TelemetryClientFactory.initialize_telemetry_client(
+            telemetry_enabled=True,
+            session_id_hex=session_id_hex,
+            auth_provider=auth_provider,
+            host_url=host_url,
+        )
+
+        # Mock executor to raise exception during shutdown
+        mock_executor = MagicMock()
+        mock_executor.shutdown.side_effect = Exception("Executor shutdown failed")
+        TelemetryClientFactory._executor = mock_executor
+
+        # Should not raise exception (executor shutdown is wrapped in try/except)
+        TelemetryClientFactory.close(session_id_hex)
+
+        # Verify executor shutdown was attempted
+        mock_executor.shutdown.assert_called_once_with(wait=True)
+
+
+class TestTelemetryRaceConditions:
+    """Tests for race conditions in multithreaded scenarios."""
+
+    @pytest.fixture
+    def race_condition_setup(self):
+        """Setup for race condition tests."""
+        # Reset telemetry system
+        TelemetryClientFactory._clients.clear()
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+            TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+        yield
+
+        # Cleanup
+        TelemetryClientFactory._clients.clear()
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+            TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+    def test_telemetry_client_concurrent_export_events(self, race_condition_setup):
"""Test race conditions in TelemetryClient._export_event with concurrent access.""" + session_id_hex = "test-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _flush to avoid actual network calls + client._flush = MagicMock() + + # Track events added by each thread + thread_events = {} + lock = threading.Lock() + + def add_events(thread_id): + """Add events from a specific thread.""" + events = [] + for i in range(10): + event = f"event-{thread_id}-{i}" + client._export_event(event) + events.append(event) + + with lock: + thread_events[thread_id] = events + + # Start multiple threads adding events concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=add_events, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all events were added (no data loss due to race conditions) + total_expected_events = sum(len(events) for events in thread_events.values()) + assert len(client._events_batch) == total_expected_events + + def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup): + """Test race conditions in TelemetryClient._flush with concurrent access.""" + session_id_hex = "test-flush-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + executor = MagicMock() + + client = TelemetryClient( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=executor, + ) + + # Mock _send_telemetry to avoid actual network calls + client._send_telemetry = MagicMock() + + # Add events to trigger flush + for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1): + client._export_event(f"event-{i}") + + # Track flush operations + flush_count = 0 + flush_lock = threading.Lock() + + def concurrent_flush(): + """Call flush concurrently.""" + nonlocal flush_count + client._flush() + with flush_lock: + flush_count += 1 + + # Start multiple threads calling flush concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_flush) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify flush was called the expected number of times + assert flush_count == 10 + + # Verify _send_telemetry was called at least once (some calls may have empty batches due to lock) + assert client._send_telemetry.call_count >= 1 + + # Verify that the total events processed is correct (no data loss) + # The first flush should have processed all events, subsequent flushes should have empty batches + total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list) + assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1 + + def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access.""" + session_id_hex = "test-factory-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Track initialization attempts + init_results = [] + init_lock = threading.Lock() + + def concurrent_initialize(thread_id): + """Initialize telemetry client concurrently.""" + TelemetryClientFactory.initialize_telemetry_client( + 
telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with init_lock: + init_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__ + }) + + # Start multiple threads initializing concurrently + threads = [] + for i in range(10): + thread = threading.Thread(target=concurrent_initialize, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + assert len(init_results) == 10 + + # Verify only one client was created (no duplicate clients due to race conditions) + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, TelemetryClient) + + # Verify the client is the same for all threads (singleton behavior) + client_ids = set() + for result in init_results: + client_ids.add(id(TelemetryClientFactory.get_telemetry_client(session_id_hex))) + + assert len(client_ids) == 1 + + def test_telemetry_client_factory_concurrent_get_client(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.get_telemetry_client with concurrent access.""" + session_id_hex = "test-get-client-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + # Track get_client attempts + get_results = [] + get_lock = threading.Lock() + + def concurrent_get_client(thread_id): + """Get telemetry client concurrently.""" + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + + with get_lock: + get_results.append({ + 'thread_id': thread_id, + 'client_type': type(client).__name__, + 'client_id': id(client) + }) + + # Start multiple threads getting client concurrently + threads = [] + for i in range(20): + thread = threading.Thread(target=concurrent_get_client, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all get_client calls succeeded + assert len(get_results) == 20 + + # Verify all threads got the same client instance (no race conditions) + client_ids = set(result['client_id'] for result in get_results) + assert len(client_ids) == 1 # Only one client instance returned + + def test_telemetry_client_factory_concurrent_close(self, race_condition_setup): + """Test race conditions in TelemetryClientFactory.close with concurrent access.""" + session_id_hex = "test-close-race-uuid" + auth_provider = MagicMock() + host_url = "test-host" + + # Initialize a client first + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + ) + + def concurrent_close(thread_id): + """Close telemetry client concurrently.""" + + TelemetryClientFactory.close(session_id_hex) + + # Start multiple threads closing concurrently + threads = [] + for i in range(5): + thread = threading.Thread(target=concurrent_close, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify client is no longer available after close + client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + assert isinstance(client, NoopTelemetryClient) + + def 
+    def test_telemetry_client_factory_mixed_concurrent_operations(self, race_condition_setup):
+        """Test race conditions with mixed concurrent operations on TelemetryClientFactory."""
+        session_id_hex = "test-mixed-race-uuid"
+        auth_provider = MagicMock()
+        host_url = "test-host"
+
+        # Track operation results
+        operation_results = []
+        operation_lock = threading.Lock()
+
+        def mixed_operations(thread_id):
+            """Perform mixed operations concurrently."""
+            # Randomly choose an operation
+            operation = random.choice(['init', 'get', 'close'])
+
+            if operation == 'init':
+                TelemetryClientFactory.initialize_telemetry_client(
+                    telemetry_enabled=True,
+                    session_id_hex=session_id_hex,
+                    auth_provider=auth_provider,
+                    host_url=host_url,
+                )
+                client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+
+                with operation_lock:
+                    operation_results.append({
+                        'thread_id': thread_id,
+                        'operation': 'init',
+                        'client_type': type(client).__name__
+                    })
+
+            elif operation == 'get':
+                client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
+
+                with operation_lock:
+                    operation_results.append({
+                        'thread_id': thread_id,
+                        'operation': 'get',
+                        'client_type': type(client).__name__
+                    })
+
+            elif operation == 'close':
+                TelemetryClientFactory.close(session_id_hex)
+
+                with operation_lock:
+                    operation_results.append({
+                        'thread_id': thread_id,
+                        'operation': 'close'
+                    })
+
+        # Start multiple threads performing mixed operations
+        threads = []
+        for i in range(15):
+            thread = threading.Thread(target=mixed_operations, args=(i,))
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        assert len(operation_results) == 15

From 70fd810270dfd7db7924e76bfe0f84b9f6299b34 Mon Sep 17 00:00:00 2001
From: Sai Shree Pradhan
Date: Tue, 1 Jul 2025 14:09:46 +0530
Subject: [PATCH 2/2] used batch size instead of default batch size

Signed-off-by: Sai Shree Pradhan
---
 tests/unit/test_telemetry.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index d1611909..519f79f1 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -188,14 +188,15 @@ def test_export_event(self, telemetry_client_setup):
         client = telemetry_client_setup["client"]
         client._flush = MagicMock()
 
-        for i in range(TelemetryClient.DEFAULT_BATCH_SIZE-1):
+        batch_size = client._batch_size
+
+        for i in range(batch_size - 1):
             client._export_event(f"event-{i}")
 
         client._flush.assert_not_called()
-        assert len(client._events_batch) == TelemetryClient.DEFAULT_BATCH_SIZE - 1
-
-        # Add one more event to reach batch size (this will trigger flush)
-        client._export_event(f"event-{TelemetryClient.DEFAULT_BATCH_SIZE - 1}")
+        assert len(client._events_batch) == batch_size - 1
+
+        client._export_event(f"event-{batch_size - 1}")
 
         client._flush.assert_called_once()
 
@@ -658,11 +659,9 @@ def test_telemetry_client_concurrent_flush_operations(self, race_condition_setup
             executor=executor,
         )
 
-        # Mock _send_telemetry to avoid actual network calls
         client._send_telemetry = MagicMock()
 
-        # Add events to trigger flush
-        for i in range(TelemetryClient.DEFAULT_BATCH_SIZE - 1):
+        for i in range(client._batch_size - 1):
             client._export_event(f"event-{i}")
 
         # Track flush operations
@@ -690,13 +689,13 @@ def concurrent_flush():
         # Verify flush was called the expected number of times
         assert flush_count == 10
 
-        # Verify _send_telemetry was called at least once (some calls may have empty batches due to lock)
-        assert client._send_telemetry.call_count >= 1
+        # Verify _send_telemetry was called once
+        assert client._send_telemetry.call_count == 1
 
         # Verify that the total events processed is correct (no data loss)
         # The first flush should have processed all events, subsequent flushes should have empty batches
         total_events_sent = sum(len(call.args[0]) for call in client._send_telemetry.call_args_list)
-        assert total_events_sent == TelemetryClient.DEFAULT_BATCH_SIZE - 1
+        assert total_events_sent == client._batch_size - 1
 
     def test_telemetry_client_factory_concurrent_initialization(self, race_condition_setup):
         """Test race conditions in TelemetryClientFactory.initialize_telemetry_client with concurrent access."""
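
Note: the concurrent-flush assertions above (exactly one _send_telemetry call, no events lost) assume that TelemetryClient._flush swaps the pending batch out under a lock, so only the first thread to acquire the lock sees a non-empty batch. A minimal sketch of that pattern is below; the class and its names are illustrative assumptions, not the connector's actual implementation.

    import threading

    class BatchingClient:
        # Illustrative only: lock-guarded batch swap, as the race-condition tests assume.

        def __init__(self, batch_size=10):
            self._batch_size = batch_size
            self._events_batch = []
            self._lock = threading.Lock()

        def _export_event(self, event):
            # Append under the lock; flush once the batch reaches the configured size.
            with self._lock:
                self._events_batch.append(event)
                should_flush = len(self._events_batch) >= self._batch_size
            if should_flush:
                self._flush()

        def _flush(self):
            # Swap the batch out under the lock so concurrent flushes see an empty list.
            with self._lock:
                to_send, self._events_batch = self._events_batch, []
            if to_send:
                self._send_telemetry(to_send)

        def _send_telemetry(self, events):
            # Placeholder for the network call (mocked out in the tests above).
            pass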