diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 48c5bba73..1728191ba 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -3,6 +3,7 @@ # import logging +import threading from abc import abstractmethod from datetime import timedelta from json import JSONDecodeError @@ -41,6 +42,13 @@ class AbstractOauth2Authenticator(AuthBase): _NO_STREAM_NAME = None + # Class-level lock to prevent concurrent token refresh across multiple authenticator instances. + # This is necessary because multiple streams may share the same OAuth credentials (refresh token) + # through the connector config. Without this lock, concurrent refresh attempts can cause race + # conditions where one stream successfully refreshes the token while others fail because the + # refresh token has been invalidated (especially for single-use refresh tokens). + _token_refresh_lock: threading.Lock = threading.Lock() + def __init__( self, refresh_token_error_status_codes: Tuple[int, ...] = (), @@ -86,9 +94,19 @@ def get_auth_header(self) -> Mapping[str, Any]: return {"Authorization": f"Bearer {token}"} def get_access_token(self) -> str: - """Returns the access token""" + """ + Returns the access token. + + This method uses double-checked locking to ensure thread-safe token refresh. + When multiple threads (streams) detect an expired token simultaneously, only one + will perform the refresh while others wait. After acquiring the lock, the token + expiry is re-checked to avoid redundant refresh attempts. + """ if self.token_has_expired(): - self.refresh_and_set_access_token() + with self._token_refresh_lock: + # Double-check after acquiring lock - another thread may have already refreshed + if self.token_has_expired(): + self.refresh_and_set_access_token() return self.access_token diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index cb64eb3e3..a32845a6e 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -319,13 +319,21 @@ def token_has_expired(self) -> bool: def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. + This method uses double-checked locking to ensure thread-safe token refresh. + This is especially critical for single-use refresh tokens where concurrent + refresh attempts would cause failures as the refresh token is invalidated + after first use. + The new refresh token is persisted with the set_refresh_token function. Returns: str: The current access_token, updated if it was previously expired. """ if self.token_has_expired(): - self.refresh_and_set_access_token() + with self._token_refresh_lock: + # Double-check after acquiring lock - another thread may have already refreshed + if self.token_has_expired(): + self.refresh_and_set_access_token() return self.access_token def refresh_and_set_access_token(self) -> None: diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index dbfc0ac86..71183b5aa 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -4,6 +4,8 @@ import json import logging +import threading +import time from datetime import timedelta from typing import Optional, Union from unittest.mock import Mock @@ -785,3 +787,132 @@ def mock_request(method, url, data, headers): raise Exception( f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" ) + + +class TestConcurrentTokenRefresh: + """ + Test class for verifying thread-safe token refresh behavior. + + These tests ensure that when multiple threads (streams) attempt to refresh + an expired token simultaneously, only one refresh actually occurs and + others wait and use the refreshed token. + """ + + def test_concurrent_token_refresh_only_refreshes_once(self, mocker): + """ + When multiple threads detect an expired token and try to refresh simultaneously, + only one thread should actually perform the refresh. Others should wait and + use the newly refreshed token. + """ + refresh_call_count = 0 + refresh_call_lock = threading.Lock() + + def mock_refresh_access_token(self): + nonlocal refresh_call_count + with refresh_call_lock: + refresh_call_count += 1 + time.sleep(0.1) + return ("new_access_token", ab_datetime_now() + timedelta(hours=1)) + + mocker.patch.object( + Oauth2Authenticator, + "refresh_access_token", + mock_refresh_access_token, + ) + + oauth = Oauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com", + client_id="client_id", + client_secret="client_secret", + refresh_token="refresh_token", + token_expiry_date=ab_datetime_now() - timedelta(hours=1), + ) + + results = [] + errors = [] + + def get_token(): + try: + token = oauth.get_access_token() + results.append(token) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=get_token) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 5 + assert all(token == "new_access_token" for token in results) + assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}" + + def test_single_use_refresh_token_concurrent_refresh_only_refreshes_once(self, mocker): + """ + For SingleUseRefreshTokenOauth2Authenticator, concurrent refresh attempts + should also only result in one actual refresh to prevent invalidating + the single-use refresh token. + """ + refresh_call_count = 0 + refresh_call_lock = threading.Lock() + + connector_config = { + "credentials": { + "client_id": "client_id", + "client_secret": "client_secret", + "refresh_token": "refresh_token", + "access_token": "old_access_token", + "token_expiry_date": str(ab_datetime_now() - timedelta(hours=1)), + } + } + + def mock_refresh_access_token(self): + nonlocal refresh_call_count + with refresh_call_lock: + refresh_call_count += 1 + time.sleep(0.1) + return ( + "new_access_token", + ab_datetime_now() + timedelta(hours=1), + "new_refresh_token", + ) + + mocker.patch.object( + SingleUseRefreshTokenOauth2Authenticator, + "refresh_access_token", + mock_refresh_access_token, + ) + + mocker.patch.object( + SingleUseRefreshTokenOauth2Authenticator, + "_emit_control_message", + lambda self: None, + ) + + oauth = SingleUseRefreshTokenOauth2Authenticator( + connector_config=connector_config, + token_refresh_endpoint="https://refresh_endpoint.com", + ) + + results = [] + errors = [] + + def get_token(): + try: + token = oauth.get_access_token() + results.append(token) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=get_token) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 5 + assert all(token == "new_access_token" for token in results) + assert refresh_call_count == 1, f"Expected 1 refresh call, got {refresh_call_count}"