From 4b6e331248dd697795621e56e97c887df55a1bf6 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 11:55:40 +0530 Subject: [PATCH 1/7] retry Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 12 +++--- .../sql/telemetry/telemetry_client.py | 40 ++++++++++++++++++- tests/unit/test_telemetry.py | 36 +++++++++++++++-- 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 30fd6c26..fbdbe6a5 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,8 +1,6 @@ import json import logging -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - logger = logging.getLogger(__name__) ### PEP-249 Mandated ### @@ -21,11 +19,11 @@ def __init__( self.context = context or {} error_name = self.__class__.__name__ - if session_id_hex: - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_failure_log(error_name, self.message) + + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory + + telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 10aa04ef..bcea4ab1 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -22,6 +22,8 @@ DatabricksOAuthProvider, ExternalAuthProvider, ) +from requests.adapters import HTTPAdapter +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType import sys import platform import uuid @@ -31,6 +33,24 @@ logger = logging.getLogger(__name__) +class TelemetryHTTPAdapter(HTTPAdapter): + """ + Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. + This ensures the retry timer is started and the command type is set correctly, + allowing the policy to manage its state for the duration of the request retries. + """ + + def send(self, request, **kwargs): + # The DatabricksRetryPolicy needs state set before the first attempt. + if isinstance(self.max_retries, DatabricksRetryPolicy): + # Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER + # to signal this to the retry policy, bypassing stricter rules for commands + # like ExecuteStatement. + self.max_retries.command_type = CommandType.OTHER + self.max_retries.start_retry_timer() + return super().send(request, **kwargs) + + class TelemetryHelper: """Helper class for getting telemetry related information.""" @@ -146,6 +166,11 @@ class TelemetryClient(BaseTelemetryClient): It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory. """ + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 + TELEMETRY_RETRY_DELAY_MIN = 0.5 # seconds + TELEMETRY_RETRY_DELAY_MAX = 5.0 # seconds + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 + # Telemetry endpoint paths TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" @@ -170,6 +195,18 @@ def __init__( self._host_url = host_url self._executor = executor + self._telemetry_retry_policy = DatabricksRetryPolicy( + delay_min=self.TELEMETRY_RETRY_DELAY_MIN, + delay_max=self.TELEMETRY_RETRY_DELAY_MAX, + stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, + stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, + delay_default=1.0, # Not directly used by telemetry, but required by constructor + force_dangerous_codes=[], # Telemetry doesn't have "dangerous" codes + ) + self._session = requests.Session() + adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy) + self._session.mount("https://", adapter) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -215,7 +252,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") future = self._executor.submit( - requests.post, + self._session.post, url, data=json.dumps(request), headers=headers, @@ -303,6 +340,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._session.close() class TelemetryClientFactory: diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 699480bb..1305f0ea 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup): client._flush.assert_called_once() assert len(client._events_batch) == 10 - @patch("requests.post") + @patch("requests.Session.post") def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): """Test sending telemetry to the server with authentication.""" client = telemetry_client_setup["client"] @@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): executor.submit.assert_called_once() args, kwargs = executor.submit.call_args - assert args[0] == requests.post + assert args[0] == client._session.post assert kwargs["timeout"] == 10 assert "Authorization" in kwargs["headers"] assert kwargs["headers"]["Authorization"] == "Bearer test-token" - @patch("requests.post") + @patch("requests.Session.post") def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup): """Test sending telemetry to the server without authentication.""" host_url = telemetry_client_setup["host_url"] @@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup) executor.submit.assert_called_once() args, kwargs = executor.submit.call_args - assert args[0] == requests.post + assert args[0] == unauthenticated_client._session.post assert kwargs["timeout"] == 10 assert "Authorization" not in kwargs["headers"] # No auth header assert kwargs["headers"]["Accept"] == "application/json" @@ -331,6 +331,34 @@ class TestBaseClient(BaseTelemetryClient): with pytest.raises(TypeError): TestBaseClient() # Can't instantiate abstract class + def test_telemetry_http_adapter_retry_policy(self, telemetry_client_setup): + """Test that TelemetryHTTPAdapter properly configures DatabricksRetryPolicy.""" + from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter + from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType + + client = telemetry_client_setup["client"] + + # Verify that the session has the TelemetryHTTPAdapter mounted + adapter = client._session.adapters.get("https://") + assert isinstance(adapter, TelemetryHTTPAdapter) + assert isinstance(adapter.max_retries, DatabricksRetryPolicy) + + # Verify that the retry policy has the correct configuration + retry_policy = adapter.max_retries + assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN + assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX + assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT + assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION + + # Test that the adapter's send method would properly configure the retry policy + # by directly testing the logic that sets command_type and starts the timer + if isinstance(adapter.max_retries, DatabricksRetryPolicy): + adapter.max_retries.command_type = CommandType.OTHER + adapter.max_retries.start_retry_timer() + + # Verify that the retry policy was configured correctly + assert retry_policy.command_type == CommandType.OTHER + class TestTelemetryHelper: """Tests for the TelemetryHelper class.""" From a92821bb62a55643b9e105b8941cd4fe58a4bf5b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 25 Jun 2025 14:26:26 +0530 Subject: [PATCH 2/7] changed error to debug log Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index bcea4ab1..83435db6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -440,7 +440,7 @@ def get_telemetry_client(session_id_hex): if session_id_hex in TelemetryClientFactory._clients: return TelemetryClientFactory._clients[session_id_hex] else: - logger.error( + logger.debug( "Telemetry client not initialized for connection %s", session_id_hex, ) From e684cc3abc2a9549ae73eebb600113be1a215e83 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 27 Jun 2025 10:29:52 +0530 Subject: [PATCH 3/7] added e2e tests to check retry logic in telemetry Signed-off-by: Sai Shree Pradhan --- tests/e2e/test_telemetry_retry.py | 213 ++++++++++++++++++++++++++++++ tests/unit/test_telemetry.py | 17 +-- 2 files changed, 217 insertions(+), 13 deletions(-) create mode 100644 tests/e2e/test_telemetry_retry.py diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py new file mode 100644 index 00000000..57274ed2 --- /dev/null +++ b/tests/e2e/test_telemetry_retry.py @@ -0,0 +1,213 @@ +# tests/e2e/test_telemetry_retry.py + +import pytest +import logging +from unittest.mock import patch, MagicMock +from functools import wraps +import time +from concurrent.futures import Future + +# Imports for the code being tested +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType +from databricks.sql.telemetry.models.enums import AuthMech +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType + +# Imports for mocking the network layer correctly +from urllib3.connectionpool import HTTPSConnectionPool +from urllib3.exceptions import MaxRetryError +from requests.exceptions import ConnectionError as RequestsConnectionError + +PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' + +# Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse. +def create_urllib3_response(status, headers=None, body=b'{}'): + """Create a proper mock response that simulates urllib3's HTTPResponse""" + mock_response = MagicMock() + mock_response.status = status + mock_response.headers = headers or {} + mock_response.msg = headers or {} # For urllib3~=1.0 compatibility + mock_response.data = body + mock_response.read.return_value = body + mock_response.get_redirect_location.return_value = False + mock_response.closed = False + mock_response.isclosed.return_value = False + return mock_response + +@pytest.mark.usefixtures("caplog") +class TestTelemetryClientRetries: + """ + Test suite for verifying the retry mechanism of the TelemetryClient. + This suite patches the low-level urllib3 connection to correctly + trigger and test the retry logic configured in the requests adapter. + """ + + @pytest.fixture(autouse=True) + def setup_and_teardown(self, caplog): + caplog.set_level(logging.DEBUG) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + yield + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + + def get_client(self, session_id, total_retries=3): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test.databricks.com", + ) + client = TelemetryClientFactory.get_telemetry_client(session_id) + + retry_policy = DatabricksRetryPolicy( + delay_min=0.01, + delay_max=0.02, + stop_after_attempts_duration=2.0, + stop_after_attempts_count=total_retries, + delay_default=0.1, + force_dangerous_codes=[], + urllib3_kwargs={'total': total_retries} + ) + adapter = client._session.adapters.get("https://") + adapter.max_retries = retry_policy + return client, adapter + + def wait_for_async_request(self, timeout=2.0): + """Wait for async telemetry request to complete""" + start_time = time.time() + while time.time() - start_time < timeout: + if TelemetryClientFactory._executor and TelemetryClientFactory._executor._threads: + # Wait a bit more for threads to complete + time.sleep(0.1) + else: + break + time.sleep(0.1) # Extra buffer for completion + + def test_success_no_retry(self): + client, _ = self.get_client("session-success") + params = DriverConnectionParameters( + http_path="test-path", + mode=DatabricksClientType.THRIFT, + host_info=HostDetails(host_url="test.databricks.com", port=443), + auth_mech=AuthMech.PAT + ) + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) + + client.export_initial_telemetry_log(params, "test-agent") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + mock_get_conn.return_value.getresponse.assert_called_once() + + def test_retry_on_503_then_succeeds(self): + client, _ = self.get_client("session-retry-once") + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.side_effect = [ + create_urllib3_response(503), + create_urllib3_response(200), + ] + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + assert mock_get_conn.return_value.getresponse.call_count == 2 + + def test_respects_retry_after_header(self, caplog): + client, _ = self.get_client("session-retry-after") + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.side_effect = [ + create_urllib3_response(429, headers={'Retry-After': '1'}), # Use integer seconds to avoid parsing issues + create_urllib3_response(200) + ] + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + # Check that the request was retried (should be 2 calls: initial + 1 retry) + assert mock_get_conn.return_value.getresponse.call_count == 2 + assert "Retrying after" in caplog.text + + def test_exceeds_retry_count_limit(self, caplog): + client, _ = self.get_client("session-exceed-limit", total_retries=3) + expected_call_count = 4 + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(503) + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + assert mock_get_conn.return_value.getresponse.call_count == expected_call_count + assert "Telemetry request failed with exception" in caplog.text + assert "Max retries exceeded" in caplog.text + + def test_no_retry_on_401_unauthorized(self, caplog): + """Test that 401 responses are not retried (per retry policy)""" + client, _ = self.get_client("session-401") + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(401) + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + # 401 should not be retried based on the retry policy + mock_get_conn.return_value.getresponse.assert_called_once() + assert "Telemetry request failed with status code: 401" in caplog.text + + def test_retries_on_400_bad_request(self, caplog): + """Test that 400 responses are retried (this is the current behavior for telemetry)""" + client, _ = self.get_client("session-400") + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(400) + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + # Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER) + expected_call_count = 4 # total + 1 (initial + 3 retries) + assert mock_get_conn.return_value.getresponse.call_count == expected_call_count + assert "Telemetry request failed with exception" in caplog.text + assert "Max retries exceeded" in caplog.text + + def test_no_retry_on_403_forbidden(self, caplog): + """Test that 403 responses are not retried (per retry policy)""" + client, _ = self.get_client("session-403") + with patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(403) + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + # 403 should not be retried based on the retry policy + mock_get_conn.return_value.getresponse.assert_called_once() + assert "Telemetry request failed with status code: 403" in caplog.text + + def test_retry_policy_command_type_is_set_to_other(self): + client, adapter = self.get_client("session-command-type") + + original_send = adapter.send + @wraps(original_send) + def wrapper(request, **kwargs): + assert adapter.max_retries.command_type == CommandType.OTHER + return original_send(request, **kwargs) + + with patch.object(adapter, 'send', side_effect=wrapper, autospec=True), \ + patch(PATCH_TARGET) as mock_get_conn: + mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) + + client.export_failure_log("TestError", "Test message") + self.wait_for_async_request() + TelemetryClientFactory.close(client._session_id_hex) + + assert adapter.send.call_count == 1 \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 1305f0ea..8a8f974c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -331,10 +331,10 @@ class TestBaseClient(BaseTelemetryClient): with pytest.raises(TypeError): TestBaseClient() # Can't instantiate abstract class - def test_telemetry_http_adapter_retry_policy(self, telemetry_client_setup): - """Test that TelemetryHTTPAdapter properly configures DatabricksRetryPolicy.""" + def test_telemetry_http_adapter_configuration(self, telemetry_client_setup): + """Test that TelemetryHTTPAdapter is properly configured with correct retry parameters.""" from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter - from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType + from databricks.sql.auth.retry import DatabricksRetryPolicy client = telemetry_client_setup["client"] @@ -343,21 +343,12 @@ def test_telemetry_http_adapter_retry_policy(self, telemetry_client_setup): assert isinstance(adapter, TelemetryHTTPAdapter) assert isinstance(adapter.max_retries, DatabricksRetryPolicy) - # Verify that the retry policy has the correct configuration + # Verify that the retry policy has the correct static configuration retry_policy = adapter.max_retries assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION - - # Test that the adapter's send method would properly configure the retry policy - # by directly testing the logic that sets command_type and starts the timer - if isinstance(adapter.max_retries, DatabricksRetryPolicy): - adapter.max_retries.command_type = CommandType.OTHER - adapter.max_retries.start_retry_timer() - - # Verify that the retry policy was configured correctly - assert retry_policy.command_type == CommandType.OTHER class TestTelemetryHelper: From d135406c8c5ee6608116f3ec9852f08e6eee6015 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 14:24:56 +0530 Subject: [PATCH 4/7] removed caplog Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/exc.py | 10 +- .../sql/telemetry/telemetry_client.py | 9 +- tests/e2e/test_telemetry_retry.py | 209 ++++++------------ tests/unit/test_telemetry.py | 20 -- 4 files changed, 72 insertions(+), 176 deletions(-) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index fbdbe6a5..4a772c49 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -19,11 +19,13 @@ def __init__( self.context = context or {} error_name = self.__class__.__name__ + if session_id_hex: + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - - telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - telemetry_client.export_failure_log(error_name, self.message) + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 83435db6..62c6a5fc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -41,13 +41,8 @@ class TelemetryHTTPAdapter(HTTPAdapter): """ def send(self, request, **kwargs): - # The DatabricksRetryPolicy needs state set before the first attempt. - if isinstance(self.max_retries, DatabricksRetryPolicy): - # Telemetry requests are idempotent and safe to retry. We use CommandType.OTHER - # to signal this to the retry policy, bypassing stricter rules for commands - # like ExecuteStatement. - self.max_retries.command_type = CommandType.OTHER - self.max_retries.start_retry_timer() + self.max_retries.command_type = CommandType.OTHER + self.max_retries.start_retry_timer() return super().send(request, **kwargs) diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py index 57274ed2..cbbcb8ad 100644 --- a/tests/e2e/test_telemetry_retry.py +++ b/tests/e2e/test_telemetry_retry.py @@ -1,50 +1,34 @@ -# tests/e2e/test_telemetry_retry.py - import pytest -import logging from unittest.mock import patch, MagicMock -from functools import wraps -import time -from concurrent.futures import Future +import io -# Imports for the code being tested from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType from databricks.sql.telemetry.models.enums import AuthMech -from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType - -# Imports for mocking the network layer correctly -from urllib3.connectionpool import HTTPSConnectionPool -from urllib3.exceptions import MaxRetryError -from requests.exceptions import ConnectionError as RequestsConnectionError +from databricks.sql.auth.retry import DatabricksRetryPolicy PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' -# Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse. -def create_urllib3_response(status, headers=None, body=b'{}'): - """Create a proper mock response that simulates urllib3's HTTPResponse""" - mock_response = MagicMock() - mock_response.status = status - mock_response.headers = headers or {} - mock_response.msg = headers or {} # For urllib3~=1.0 compatibility - mock_response.data = body - mock_response.read.return_value = body - mock_response.get_redirect_location.return_value = False - mock_response.closed = False - mock_response.isclosed.return_value = False - return mock_response +def create_mock_conn(responses): + """Creates a mock connection object whose getresponse() method yields a series of responses.""" + mock_conn = MagicMock() + mock_http_responses = [] + for resp in responses: + mock_http_response = MagicMock() + mock_http_response.status = resp.get("status") + mock_http_response.headers = resp.get("headers", {}) + body = resp.get("body", b'{}') + mock_http_response.fp = io.BytesIO(body) + def release(): + mock_http_response.fp.close() + mock_http_response.release_conn = release + mock_http_responses.append(mock_http_response) + mock_conn.getresponse.side_effect = mock_http_responses + return mock_conn -@pytest.mark.usefixtures("caplog") class TestTelemetryClientRetries: - """ - Test suite for verifying the retry mechanism of the TelemetryClient. - This suite patches the low-level urllib3 connection to correctly - trigger and test the retry logic configured in the requests adapter. - """ - @pytest.fixture(autouse=True) - def setup_and_teardown(self, caplog): - caplog.set_level(logging.DEBUG) + def setup_and_teardown(self): TelemetryClientFactory._initialized = False TelemetryClientFactory._clients = {} TelemetryClientFactory._executor = None @@ -55,7 +39,10 @@ def setup_and_teardown(self, caplog): TelemetryClientFactory._clients = {} TelemetryClientFactory._executor = None - def get_client(self, session_id, total_retries=3): + def get_client(self, session_id, num_retries=3): + """ + Configures a client with a specific number of retries. + """ TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, @@ -63,151 +50,83 @@ def get_client(self, session_id, total_retries=3): host_url="test.databricks.com", ) client = TelemetryClientFactory.get_telemetry_client(session_id) - + retry_policy = DatabricksRetryPolicy( delay_min=0.01, delay_max=0.02, stop_after_attempts_duration=2.0, - stop_after_attempts_count=total_retries, + stop_after_attempts_count=num_retries, delay_default=0.1, force_dangerous_codes=[], - urllib3_kwargs={'total': total_retries} + urllib3_kwargs={'total': num_retries} ) adapter = client._session.adapters.get("https://") adapter.max_retries = retry_policy return client, adapter - def wait_for_async_request(self, timeout=2.0): - """Wait for async telemetry request to complete""" - start_time = time.time() - while time.time() - start_time < timeout: - if TelemetryClientFactory._executor and TelemetryClientFactory._executor._threads: - # Wait a bit more for threads to complete - time.sleep(0.1) - else: - break - time.sleep(0.1) # Extra buffer for completion - def test_success_no_retry(self): client, _ = self.get_client("session-success") params = DriverConnectionParameters( - http_path="test-path", - mode=DatabricksClientType.THRIFT, + http_path="test-path", mode=DatabricksClientType.THRIFT, host_info=HostDetails(host_url="test.databricks.com", port=443), auth_mech=AuthMech.PAT ) - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) - + mock_responses = [{"status": 200}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: client.export_initial_telemetry_log(params, "test-agent") - self.wait_for_async_request() TelemetryClientFactory.close(client._session_id_hex) mock_get_conn.return_value.getresponse.assert_called_once() - - def test_retry_on_503_then_succeeds(self): - client, _ = self.get_client("session-retry-once") - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.side_effect = [ - create_urllib3_response(503), - create_urllib3_response(200), - ] - - client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() - TelemetryClientFactory.close(client._session_id_hex) - - assert mock_get_conn.return_value.getresponse.call_count == 2 - - def test_respects_retry_after_header(self, caplog): - client, _ = self.get_client("session-retry-after") - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.side_effect = [ - create_urllib3_response(429, headers={'Retry-After': '1'}), # Use integer seconds to avoid parsing issues - create_urllib3_response(200) - ] - + client, _ = self.get_client("session-retry-once", num_retries=1) + mock_responses = [{"status": 503}, {"status": 200}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() TelemetryClientFactory.close(client._session_id_hex) - # Check that the request was retried (should be 2 calls: initial + 1 retry) assert mock_get_conn.return_value.getresponse.call_count == 2 - assert "Retrying after" in caplog.text - def test_exceeds_retry_count_limit(self, caplog): - client, _ = self.get_client("session-exceed-limit", total_retries=3) - expected_call_count = 4 - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(503) - + @pytest.mark.parametrize( + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + ], + ) + def test_non_retryable_status_codes_are_not_retried(self, status_code, description): + """ + Verifies that terminal error codes (401, 403, 501, etc.) are not retried. + """ + # Use the status code in the session ID for easier debugging if it fails + client, _ = self.get_client(f"session-{status_code}") + mock_responses = [{"status": status_code}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() TelemetryClientFactory.close(client._session_id_hex) - - assert mock_get_conn.return_value.getresponse.call_count == expected_call_count - assert "Telemetry request failed with exception" in caplog.text - assert "Max retries exceeded" in caplog.text - def test_no_retry_on_401_unauthorized(self, caplog): - """Test that 401 responses are not retried (per retry policy)""" - client, _ = self.get_client("session-401") - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(401) - - client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() - TelemetryClientFactory.close(client._session_id_hex) - - # 401 should not be retried based on the retry policy mock_get_conn.return_value.getresponse.assert_called_once() - assert "Telemetry request failed with status code: 401" in caplog.text - def test_retries_on_400_bad_request(self, caplog): - """Test that 400 responses are retried (this is the current behavior for telemetry)""" - client, _ = self.get_client("session-400") - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(400) - - client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() - TelemetryClientFactory.close(client._session_id_hex) - - # Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER) - expected_call_count = 4 # total + 1 (initial + 3 retries) - assert mock_get_conn.return_value.getresponse.call_count == expected_call_count - assert "Telemetry request failed with exception" in caplog.text - assert "Max retries exceeded" in caplog.text - - def test_no_retry_on_403_forbidden(self, caplog): - """Test that 403 responses are not retried (per retry policy)""" - client, _ = self.get_client("session-403") - with patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(403) - + def test_respects_retry_after_header(self): + client, _ = self.get_client("session-retry-after", num_retries=1) + mock_responses = [{"status": 429, "headers": {'Retry-After': '1'}}, {"status": 200}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() TelemetryClientFactory.close(client._session_id_hex) - # 403 should not be retried based on the retry policy - mock_get_conn.return_value.getresponse.assert_called_once() - assert "Telemetry request failed with status code: 403" in caplog.text + assert mock_get_conn.return_value.getresponse.call_count == 2 - def test_retry_policy_command_type_is_set_to_other(self): - client, adapter = self.get_client("session-command-type") + def test_exceeds_retry_count_limit(self): + num_retries = 3 + expected_total_calls = num_retries + 1 + client, _ = self.get_client("session-exceed-limit", num_retries=num_retries) + mock_responses = [{"status": 503}] * expected_total_calls - original_send = adapter.send - @wraps(original_send) - def wrapper(request, **kwargs): - assert adapter.max_retries.command_type == CommandType.OTHER - return original_send(request, **kwargs) - - with patch.object(adapter, 'send', side_effect=wrapper, autospec=True), \ - patch(PATCH_TARGET) as mock_get_conn: - mock_get_conn.return_value.getresponse.return_value = create_urllib3_response(200) - + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: client.export_failure_log("TestError", "Test message") - self.wait_for_async_request() TelemetryClientFactory.close(client._session_id_hex) - assert adapter.send.call_count == 1 \ No newline at end of file + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 8a8f974c..f7edc89c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -331,26 +331,6 @@ class TestBaseClient(BaseTelemetryClient): with pytest.raises(TypeError): TestBaseClient() # Can't instantiate abstract class - def test_telemetry_http_adapter_configuration(self, telemetry_client_setup): - """Test that TelemetryHTTPAdapter is properly configured with correct retry parameters.""" - from databricks.sql.telemetry.telemetry_client import TelemetryHTTPAdapter - from databricks.sql.auth.retry import DatabricksRetryPolicy - - client = telemetry_client_setup["client"] - - # Verify that the session has the TelemetryHTTPAdapter mounted - adapter = client._session.adapters.get("https://") - assert isinstance(adapter, TelemetryHTTPAdapter) - assert isinstance(adapter.max_retries, DatabricksRetryPolicy) - - # Verify that the retry policy has the correct static configuration - retry_policy = adapter.max_retries - assert retry_policy.delay_min == client.TELEMETRY_RETRY_DELAY_MIN - assert retry_policy.delay_max == client.TELEMETRY_RETRY_DELAY_MAX - assert retry_policy.stop_after_attempts_count == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT - assert retry_policy.stop_after_attempts_duration == client.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION - - class TestTelemetryHelper: """Tests for the TelemetryHelper class.""" From 48345ec18725a79b95b5f57934ee42f36cdef09b Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 14:55:59 +0530 Subject: [PATCH 5/7] retry policy default values Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 62c6a5fc..986c7f89 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -162,8 +162,8 @@ class TelemetryClient(BaseTelemetryClient): """ TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 0.5 # seconds - TELEMETRY_RETRY_DELAY_MAX = 5.0 # seconds + TELEMETRY_RETRY_DELAY_MIN = 1.0 + TELEMETRY_RETRY_DELAY_MAX = 10.0 TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 # Telemetry endpoint paths @@ -195,8 +195,8 @@ def __init__( delay_max=self.TELEMETRY_RETRY_DELAY_MAX, stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, # Not directly used by telemetry, but required by constructor - force_dangerous_codes=[], # Telemetry doesn't have "dangerous" codes + delay_default=1.0, + force_dangerous_codes=[], ) self._session = requests.Session() adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy) From e7d277994e3d56fc60dbe4a15a23d4be9317c606 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 2 Jul 2025 15:01:30 +0530 Subject: [PATCH 6/7] formatting Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 986c7f89..32a065ab 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -162,8 +162,8 @@ class TelemetryClient(BaseTelemetryClient): """ TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 1.0 - TELEMETRY_RETRY_DELAY_MAX = 10.0 + TELEMETRY_RETRY_DELAY_MIN = 1.0 + TELEMETRY_RETRY_DELAY_MAX = 10.0 TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 # Telemetry endpoint paths @@ -195,8 +195,8 @@ def __init__( delay_max=self.TELEMETRY_RETRY_DELAY_MAX, stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, - force_dangerous_codes=[], + delay_default=1.0, + force_dangerous_codes=[], ) self._session = requests.Session() adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy) From 7c64c7b4cfe5ded0926b9e4cb8a004ccaa7bf11f Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Thu, 3 Jul 2025 11:50:57 +0530 Subject: [PATCH 7/7] compact tests Signed-off-by: Sai Shree Pradhan --- tests/e2e/test_telemetry_retry.py | 51 ++++++++----------------------- 1 file changed, 13 insertions(+), 38 deletions(-) diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py index cbbcb8ad..70089b7d 100644 --- a/tests/e2e/test_telemetry_retry.py +++ b/tests/e2e/test_telemetry_retry.py @@ -1,10 +1,9 @@ import pytest from unittest.mock import patch, MagicMock import io +import time from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory -from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType -from databricks.sql.telemetry.models.enums import AuthMech from databricks.sql.auth.retry import DatabricksRetryPolicy PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' @@ -64,40 +63,18 @@ def get_client(self, session_id, num_retries=3): adapter.max_retries = retry_policy return client, adapter - def test_success_no_retry(self): - client, _ = self.get_client("session-success") - params = DriverConnectionParameters( - http_path="test-path", mode=DatabricksClientType.THRIFT, - host_info=HostDetails(host_url="test.databricks.com", port=443), - auth_mech=AuthMech.PAT - ) - mock_responses = [{"status": 200}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: - client.export_initial_telemetry_log(params, "test-agent") - TelemetryClientFactory.close(client._session_id_hex) - - mock_get_conn.return_value.getresponse.assert_called_once() - client, _ = self.get_client("session-retry-once", num_retries=1) - mock_responses = [{"status": 503}, {"status": 200}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: - client.export_failure_log("TestError", "Test message") - TelemetryClientFactory.close(client._session_id_hex) - - assert mock_get_conn.return_value.getresponse.call_count == 2 - @pytest.mark.parametrize( "status_code, description", [ (401, "Unauthorized"), (403, "Forbidden"), (501, "Not Implemented"), + (200, "Success"), ], ) def test_non_retryable_status_codes_are_not_retried(self, status_code, description): """ - Verifies that terminal error codes (401, 403, 501, etc.) are not retried. + Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. """ # Use the status code in the session ID for easier debugging if it fails client, _ = self.get_client(f"session-{status_code}") @@ -109,24 +86,22 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti mock_get_conn.return_value.getresponse.assert_called_once() - def test_respects_retry_after_header(self): - client, _ = self.get_client("session-retry-after", num_retries=1) - mock_responses = [{"status": 429, "headers": {'Retry-After': '1'}}, {"status": 200}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: - client.export_failure_log("TestError", "Test message") - TelemetryClientFactory.close(client._session_id_hex) - - assert mock_get_conn.return_value.getresponse.call_count == 2 - def test_exceeds_retry_count_limit(self): + """ + Verifies that the client retries up to the specified number of times before giving up. + Verifies that the client respects the Retry-After header and retries on 429, 502, 503. + """ num_retries = 3 expected_total_calls = num_retries + 1 + retry_after = 1 client, _ = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [{"status": 503}] * expected_total_calls + mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + start_time = time.time() client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) + end_time = time.time() - assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls \ No newline at end of file + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls + assert end_time - start_time > retry_after \ No newline at end of file