diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 30fd6c26..4a772c49 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -1,8 +1,6 @@ import json import logging -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - logger = logging.getLogger(__name__) ### PEP-249 Mandated ### @@ -22,6 +20,8 @@ def __init__( error_name = self.__class__.__name__ if session_id_hex: + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory + telemetry_client = TelemetryClientFactory.get_telemetry_client( session_id_hex ) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 10aa04ef..32a065ab 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -22,6 +22,8 @@ DatabricksOAuthProvider, ExternalAuthProvider, ) +from requests.adapters import HTTPAdapter +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType import sys import platform import uuid @@ -31,6 +33,19 @@ logger = logging.getLogger(__name__) +class TelemetryHTTPAdapter(HTTPAdapter): + """ + Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. + This ensures the retry timer is started and the command type is set correctly, + allowing the policy to manage its state for the duration of the request retries. + """ + + def send(self, request, **kwargs): + self.max_retries.command_type = CommandType.OTHER + self.max_retries.start_retry_timer() + return super().send(request, **kwargs) + + class TelemetryHelper: """Helper class for getting telemetry related information.""" @@ -146,6 +161,11 @@ class TelemetryClient(BaseTelemetryClient): It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory. """ + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 + TELEMETRY_RETRY_DELAY_MIN = 1.0 + TELEMETRY_RETRY_DELAY_MAX = 10.0 + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 + # Telemetry endpoint paths TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" @@ -170,6 +190,18 @@ def __init__( self._host_url = host_url self._executor = executor + self._telemetry_retry_policy = DatabricksRetryPolicy( + delay_min=self.TELEMETRY_RETRY_DELAY_MIN, + delay_max=self.TELEMETRY_RETRY_DELAY_MAX, + stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, + stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, + delay_default=1.0, + force_dangerous_codes=[], + ) + self._session = requests.Session() + adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy) + self._session.mount("https://", adapter) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -215,7 +247,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") future = self._executor.submit( - requests.post, + self._session.post, url, data=json.dumps(request), headers=headers, @@ -303,6 +335,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._session.close() class TelemetryClientFactory: @@ -402,7 +435,7 @@ def get_telemetry_client(session_id_hex): if session_id_hex in TelemetryClientFactory._clients: return TelemetryClientFactory._clients[session_id_hex] else: - logger.error( + logger.debug( "Telemetry client not initialized for connection %s", session_id_hex, ) diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py new file mode 100644 index 00000000..70089b7d --- /dev/null +++ b/tests/e2e/test_telemetry_retry.py @@ -0,0 +1,107 @@ +import pytest +from unittest.mock import patch, MagicMock +import io +import time + +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.auth.retry import DatabricksRetryPolicy + +PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' + +def create_mock_conn(responses): + """Creates a mock connection object whose getresponse() method yields a series of responses.""" + mock_conn = MagicMock() + mock_http_responses = [] + for resp in responses: + mock_http_response = MagicMock() + mock_http_response.status = resp.get("status") + mock_http_response.headers = resp.get("headers", {}) + body = resp.get("body", b'{}') + mock_http_response.fp = io.BytesIO(body) + def release(): + mock_http_response.fp.close() + mock_http_response.release_conn = release + mock_http_responses.append(mock_http_response) + mock_conn.getresponse.side_effect = mock_http_responses + return mock_conn + +class TestTelemetryClientRetries: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + yield + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + + def get_client(self, session_id, num_retries=3): + """ + Configures a client with a specific number of retries. + """ + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test.databricks.com", + ) + client = TelemetryClientFactory.get_telemetry_client(session_id) + + retry_policy = DatabricksRetryPolicy( + delay_min=0.01, + delay_max=0.02, + stop_after_attempts_duration=2.0, + stop_after_attempts_count=num_retries, + delay_default=0.1, + force_dangerous_codes=[], + urllib3_kwargs={'total': num_retries} + ) + adapter = client._session.adapters.get("https://") + adapter.max_retries = retry_policy + return client, adapter + + @pytest.mark.parametrize( + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], + ) + def test_non_retryable_status_codes_are_not_retried(self, status_code, description): + """ + Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. + """ + # Use the status code in the session ID for easier debugging if it fails + client, _ = self.get_client(f"session-{status_code}") + mock_responses = [{"status": status_code}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + + mock_get_conn.return_value.getresponse.assert_called_once() + + def test_exceeds_retry_count_limit(self): + """ + Verifies that the client retries up to the specified number of times before giving up. + Verifies that the client respects the Retry-After header and retries on 429, 502, 503. + """ + num_retries = 3 + expected_total_calls = num_retries + 1 + retry_after = 1 + client, _ = self.get_client("session-exceed-limit", num_retries=num_retries) + mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + start_time = time.time() + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + end_time = time.time() + + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls + assert end_time - start_time > retry_after \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 699480bb..f7edc89c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup): client._flush.assert_called_once() assert len(client._events_batch) == 10 - @patch("requests.post") + @patch("requests.Session.post") def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): """Test sending telemetry to the server with authentication.""" client = telemetry_client_setup["client"] @@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup): executor.submit.assert_called_once() args, kwargs = executor.submit.call_args - assert args[0] == requests.post + assert args[0] == client._session.post assert kwargs["timeout"] == 10 assert "Authorization" in kwargs["headers"] assert kwargs["headers"]["Authorization"] == "Bearer test-token" - @patch("requests.post") + @patch("requests.Session.post") def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup): """Test sending telemetry to the server without authentication.""" host_url = telemetry_client_setup["host_url"] @@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup) executor.submit.assert_called_once() args, kwargs = executor.submit.call_args - assert args[0] == requests.post + assert args[0] == unauthenticated_client._session.post assert kwargs["timeout"] == 10 assert "Authorization" not in kwargs["headers"] # No auth header assert kwargs["headers"]["Accept"] == "application/json" @@ -331,7 +331,6 @@ class TestBaseClient(BaseTelemetryClient): with pytest.raises(TypeError): TestBaseClient() # Can't instantiate abstract class - class TestTelemetryHelper: """Tests for the TelemetryHelper class."""