Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Add retry mechanism to telemetry requests #617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: telemetry
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/databricks/sql/exc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
import logging

from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

logger = logging.getLogger(__name__)

### PEP-249 Mandated ###
Expand All @@ -22,6 +20,8 @@ def __init__(

error_name = self.__class__.__name__
if session_id_hex:
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

telemetry_client = TelemetryClientFactory.get_telemetry_client(
session_id_hex
)
Expand Down
37 changes: 35 additions & 2 deletions src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from requests.adapters import HTTPAdapter
from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType
import sys
import platform
import uuid
Expand All @@ -31,6 +33,19 @@
logger = logging.getLogger(__name__)


class TelemetryHTTPAdapter(HTTPAdapter):
    """HTTP adapter that primes its retry policy before every request.

    The adapter's ``max_retries`` is expected to be a ``DatabricksRetryPolicy``
    (that is what ``TelemetryClient`` mounts). Before delegating to the stock
    ``HTTPAdapter.send`` the policy's command type is set and its retry timer
    is (re)started, so the policy can track state across that request's
    retries.
    """

    def send(self, request, **kwargs):
        """Prepare the retry policy, then send ``request`` via the parent adapter.

        Args:
            request: The prepared request to send.
            **kwargs: Forwarded unchanged to ``HTTPAdapter.send``.

        Returns:
            Whatever ``HTTPAdapter.send`` returns for this request.
        """
        retry_policy = self.max_retries
        # Telemetry posts are not SQL commands, so they are tagged as OTHER.
        retry_policy.command_type = CommandType.OTHER
        retry_policy.start_retry_timer()
        return super().send(request, **kwargs)


class TelemetryHelper:
"""Helper class for getting telemetry related information."""

Expand Down Expand Up @@ -146,6 +161,11 @@ class TelemetryClient(BaseTelemetryClient):
It uses a thread pool to handle asynchronous operations, that it gets from the TelemetryClientFactory.
"""

TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
TELEMETRY_RETRY_DELAY_MIN = 1.0
TELEMETRY_RETRY_DELAY_MAX = 10.0
TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0

# Telemetry endpoint paths
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
Expand All @@ -170,6 +190,18 @@ def __init__(
self._host_url = host_url
self._executor = executor

self._telemetry_retry_policy = DatabricksRetryPolicy(
delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
delay_default=1.0,
force_dangerous_codes=[],
)
self._session = requests.Session()
adapter = TelemetryHTTPAdapter(max_retries=self._telemetry_retry_policy)
self._session.mount("https://", adapter)

def _export_event(self, event):
"""Add an event to the batch queue and flush if batch is full"""
logger.debug("Exporting event for connection %s", self._session_id_hex)
Expand Down Expand Up @@ -215,7 +247,7 @@ def _send_telemetry(self, events):
try:
logger.debug("Submitting telemetry request to thread pool")
future = self._executor.submit(
requests.post,
self._session.post,
url,
data=json.dumps(request),
headers=headers,
Expand Down Expand Up @@ -303,6 +335,7 @@ def close(self):
"""Flush remaining events before closing"""
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
self._flush()
self._session.close()


class TelemetryClientFactory:
Expand Down Expand Up @@ -402,7 +435,7 @@ def get_telemetry_client(session_id_hex):
if session_id_hex in TelemetryClientFactory._clients:
return TelemetryClientFactory._clients[session_id_hex]
else:
logger.error(
logger.debug(
"Telemetry client not initialized for connection %s",
session_id_hex,
)
Expand Down
107 changes: 107 additions & 0 deletions tests/e2e/test_telemetry_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import pytest
from unittest.mock import patch, MagicMock
import io
import time

from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
from databricks.sql.auth.retry import DatabricksRetryPolicy

PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'

def create_mock_conn(responses):
    """Build a mock connection whose ``getresponse()`` yields canned responses.

    Args:
        responses: Iterable of dicts, one per response, served in order.
            Recognized keys are ``"status"`` (``None`` when absent),
            ``"headers"`` (defaults to ``{}``) and ``"body"`` (bytes,
            defaults to ``b'{}'``).

    Returns:
        A ``MagicMock`` connection. Its ``getresponse`` mock uses the list of
        mock responses as ``side_effect``, so each call returns the next one.
        Every mock response exposes ``status``, ``headers``, an ``io.BytesIO``
        body as ``fp``, and a ``release_conn()`` that closes that body.
    """
    mock_conn = MagicMock()
    mock_http_responses = []
    for resp in responses:
        mock_http_response = MagicMock()
        mock_http_response.status = resp.get("status")
        mock_http_response.headers = resp.get("headers", {})
        body = resp.get("body", b'{}')
        mock_http_response.fp = io.BytesIO(body)

        # Bind the current response as a default argument. A plain closure
        # would resolve the loop variable at call time, making every
        # release_conn() close the LAST response's body instead of its own.
        def release(response=mock_http_response):
            response.fp.close()

        mock_http_response.release_conn = release
        mock_http_responses.append(mock_http_response)
    mock_conn.getresponse.side_effect = mock_http_responses
    return mock_conn

class TestTelemetryClientRetries:
    """End-to-end checks of the retry behavior of telemetry HTTP requests."""

    @staticmethod
    def _reset_factory_state():
        """Put the TelemetryClientFactory class-level state back to its blank values."""
        TelemetryClientFactory._initialized = False
        TelemetryClientFactory._clients = {}
        TelemetryClientFactory._executor = None

    @pytest.fixture(autouse=True)
    def setup_and_teardown(self):
        """Reset the factory before each test; shut its executor down and reset again after."""
        self._reset_factory_state()
        yield
        executor = TelemetryClientFactory._executor
        if executor:
            # Wait for in-flight telemetry submissions before wiping state.
            executor.shutdown(wait=True)
        self._reset_factory_state()

    def get_client(self, session_id, num_retries=3):
        """Create a telemetry client whose HTTPS adapter uses a fast retry policy.

        Args:
            session_id: Session id used to register and look up the client.
            num_retries: Retry budget given to the replacement policy.

        Returns:
            A ``(client, adapter)`` tuple; ``adapter`` is the client's
            ``https://`` adapter with ``max_retries`` replaced.
        """
        TelemetryClientFactory.initialize_telemetry_client(
            telemetry_enabled=True,
            session_id_hex=session_id,
            auth_provider=None,
            host_url="test.databricks.com",
        )
        telemetry_client = TelemetryClientFactory.get_telemetry_client(session_id)

        # Short delays keep the test quick while still exercising the retries.
        fast_policy = DatabricksRetryPolicy(
            delay_min=0.01,
            delay_max=0.02,
            stop_after_attempts_duration=2.0,
            stop_after_attempts_count=num_retries,
            delay_default=0.1,
            force_dangerous_codes=[],
            urllib3_kwargs={'total': num_retries},
        )
        https_adapter = telemetry_client._session.adapters.get("https://")
        https_adapter.max_retries = fast_policy
        return telemetry_client, https_adapter

    @pytest.mark.parametrize(
        "status_code, description",
        [
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (501, "Not Implemented"),
            (200, "Success"),
        ],
    )
    def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
        """A 401, 403, 501 or 200 response must result in exactly one request."""
        # Embedding the status code in the session id makes failures easier to trace.
        telemetry_client, _ = self.get_client(f"session-{status_code}")
        fake_conn = create_mock_conn([{"status": status_code}])

        with patch(PATCH_TARGET, return_value=fake_conn):
            telemetry_client.export_failure_log("TestError", "Test message")
            TelemetryClientFactory.close(telemetry_client._session_id_hex)

        fake_conn.getresponse.assert_called_once()

    def test_exceeds_retry_count_limit(self):
        """Retryable responses are retried until the retry budget is used up.

        Serves 503 (with a ``Retry-After`` header), 429, 502 and 503 in turn,
        then checks that one initial attempt plus ``num_retries`` retries were
        made and that the run took longer than the ``Retry-After`` value.
        """
        num_retries = 3
        retry_after = 1
        telemetry_client, _ = self.get_client(
            "session-exceed-limit", num_retries=num_retries
        )
        fake_conn = create_mock_conn(
            [
                {"status": 503, "headers": {"Retry-After": str(retry_after)}},
                {"status": 429},
                {"status": 502},
                {"status": 503},
            ]
        )

        with patch(PATCH_TARGET, return_value=fake_conn):
            started = time.time()
            telemetry_client.export_failure_log("TestError", "Test message")
            TelemetryClientFactory.close(telemetry_client._session_id_hex)
            elapsed = time.time() - started

        # One initial attempt plus one call per retry.
        assert fake_conn.getresponse.call_count == num_retries + 1
        assert elapsed > retry_after
9 changes: 4 additions & 5 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_export_event(self, telemetry_client_setup):
client._flush.assert_called_once()
assert len(client._events_batch) == 10

@patch("requests.post")
@patch("requests.Session.post")
def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
"""Test sending telemetry to the server with authentication."""
client = telemetry_client_setup["client"]
Expand All @@ -212,12 +212,12 @@ def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):

executor.submit.assert_called_once()
args, kwargs = executor.submit.call_args
assert args[0] == requests.post
assert args[0] == client._session.post
assert kwargs["timeout"] == 10
assert "Authorization" in kwargs["headers"]
assert kwargs["headers"]["Authorization"] == "Bearer test-token"

@patch("requests.post")
@patch("requests.Session.post")
def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
"""Test sending telemetry to the server without authentication."""
host_url = telemetry_client_setup["host_url"]
Expand All @@ -239,7 +239,7 @@ def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup)

executor.submit.assert_called_once()
args, kwargs = executor.submit.call_args
assert args[0] == requests.post
assert args[0] == unauthenticated_client._session.post
assert kwargs["timeout"] == 10
assert "Authorization" not in kwargs["headers"] # No auth header
assert kwargs["headers"]["Accept"] == "application/json"
Expand Down Expand Up @@ -331,7 +331,6 @@ class TestBaseClient(BaseTelemetryClient):
with pytest.raises(TypeError):
TestBaseClient() # Can't instantiate abstract class


class TestTelemetryHelper:
"""Tests for the TelemetryHelper class."""

Expand Down
Loading