Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

import logging
import threading
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
Expand Down Expand Up @@ -41,6 +42,13 @@ class AbstractOauth2Authenticator(AuthBase):

_NO_STREAM_NAME = None

# Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
# This is necessary because multiple streams may share the same OAuth credentials (refresh token)
# through the connector config. Without this lock, concurrent refresh attempts can cause race
# conditions where one stream successfully refreshes the token while others fail because the
# refresh token has been invalidated (especially for single-use refresh tokens).
_token_refresh_lock: threading.Lock = threading.Lock()

def __init__(
self,
refresh_token_error_status_codes: Tuple[int, ...] = (),
Expand Down Expand Up @@ -86,9 +94,19 @@ def get_auth_header(self) -> Mapping[str, Any]:
return {"Authorization": f"Bearer {token}"}

def get_access_token(self) -> str:
"""Returns the access token"""
"""
Returns the access token.

This method uses double-checked locking to ensure thread-safe token refresh.
When multiple threads (streams) detect an expired token simultaneously, only one
will perform the refresh while others wait. After acquiring the lock, the token
expiry is re-checked to avoid redundant refresh attempts.
"""
if self.token_has_expired():
self.refresh_and_set_access_token()
with self._token_refresh_lock:
# Double-check after acquiring lock - another thread may have already refreshed
if self.token_has_expired():
self.refresh_and_set_access_token()

return self.access_token

Expand Down
10 changes: 9 additions & 1 deletion airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,21 @@ def token_has_expired(self) -> bool:
def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.

This method uses double-checked locking to ensure thread-safe token refresh.
This is especially critical for single-use refresh tokens where concurrent
refresh attempts would cause failures as the refresh token is invalidated
after first use.

The new refresh token is persisted with the set_refresh_token function.

Returns:
str: The current access_token, updated if it was previously expired.
"""
if self.token_has_expired():
self.refresh_and_set_access_token()
with self._token_refresh_lock:
# Double-check after acquiring lock - another thread may have already refreshed
if self.token_has_expired():
self.refresh_and_set_access_token()
return self.access_token

def refresh_and_set_access_token(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import json
import logging
import threading
import time
from datetime import timedelta
from typing import Optional, Union
from unittest.mock import Mock
Expand Down Expand Up @@ -785,3 +787,132 @@ def mock_request(method, url, data, headers):
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)


class TestConcurrentTokenRefresh:
"""
Test class for verifying thread-safe token refresh behavior.

These tests ensure that when multiple threads (streams) attempt to refresh
an expired token simultaneously, only one refresh actually occurs and
others wait and use the refreshed token.
"""

def test_concurrent_token_refresh_only_refreshes_once(self, mocker):
"""
When multiple threads detect an expired token and try to refresh simultaneously,
only one thread should actually perform the refresh. Others should wait and
use the newly refreshed token.
"""
refresh_call_count = 0
refresh_call_lock = threading.Lock()

def mock_refresh_access_token(self):
nonlocal refresh_call_count
with refresh_call_lock:
refresh_call_count += 1
time.sleep(0.1)
return ("new_access_token", ab_datetime_now() + timedelta(hours=1))

mocker.patch.object(
Oauth2Authenticator,
"refresh_access_token",
mock_refresh_access_token,
)

oauth = Oauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com",
client_id="client_id",
client_secret="client_secret",
refresh_token="refresh_token",
token_expiry_date=ab_datetime_now() - timedelta(hours=1),
)

results = []
errors = []

def get_token():
try:
token = oauth.get_access_token()
results.append(token)
except Exception as e:
errors.append(e)

threads = [threading.Thread(target=get_token) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

assert len(errors) == 0, f"Unexpected errors: {errors}"
assert len(results) == 5
assert all(token == "new_access_token" for token in results)
assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}"

def test_single_use_refresh_token_concurrent_refresh_only_refreshes_once(self, mocker):
"""
For SingleUseRefreshTokenOauth2Authenticator, concurrent refresh attempts
should also only result in one actual refresh to prevent invalidating
the single-use refresh token.
"""
refresh_call_count = 0
refresh_call_lock = threading.Lock()

connector_config = {
"credentials": {
"client_id": "client_id",
"client_secret": "client_secret",
"refresh_token": "refresh_token",
"access_token": "old_access_token",
"token_expiry_date": str(ab_datetime_now() - timedelta(hours=1)),
}
}

def mock_refresh_access_token(self):
nonlocal refresh_call_count
with refresh_call_lock:
refresh_call_count += 1
time.sleep(0.1)
return (
"new_access_token",
ab_datetime_now() + timedelta(hours=1),
"new_refresh_token",
)

mocker.patch.object(
SingleUseRefreshTokenOauth2Authenticator,
"refresh_access_token",
mock_refresh_access_token,
)

mocker.patch.object(
SingleUseRefreshTokenOauth2Authenticator,
"_emit_control_message",
lambda self: None,
)

oauth = SingleUseRefreshTokenOauth2Authenticator(
connector_config=connector_config,
token_refresh_endpoint="https://refresh_endpoint.com",
)

results = []
errors = []

def get_token():
try:
token = oauth.get_access_token()
results.append(token)
except Exception as e:
errors.append(e)

threads = [threading.Thread(target=get_token) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

assert len(errors) == 0, f"Unexpected errors: {errors}"
assert len(results) == 5
assert all(token == "new_access_token" for token in results)
assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}"