From f336612ab5cdea46823188b0d6a5b41242e3c802 Mon Sep 17 00:00:00 2001 From: Chris Gillum Date: Fri, 26 May 2023 05:17:53 +0000 Subject: [PATCH] Starter code for retry policies --- durabletask/task.py | 99 ++++++++++++++++++++++++---- durabletask/worker.py | 47 +++++++------ tests/test_orchestration_e2e.py | 36 ++++++++++ tests/test_orchestration_executor.py | 92 ++++++++++++++++++++++++++ 4 files changed, 241 insertions(+), 33 deletions(-) diff --git a/durabletask/task.py b/durabletask/task.py index 3acdcbe..d94d8f7 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -6,7 +6,8 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Generic, List, TypeVar, Union +from typing import (Any, Callable, Generator, Generic, List, Optional, TypeVar, + Union) import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -87,17 +88,18 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task: @abstractmethod def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, - input: Union[TInput, None] = None) -> Task[TOutput]: + input: Optional[TInput] = None, + retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]: """Schedule an activity for execution. Parameters ---------- activity: Union[Activity[TInput, TOutput], str] A reference to the activity function to call. - input: Union[TInput, None] + input: Optional[TInput] The JSON-serializable input (or None) to pass to the activity. - return_type: task.Task[TOutput] - The JSON-serializable output type to expect from the activity result. + retry_policy: Optional[RetryPolicy] + The retry policy to use for this activity call. Returns ------- @@ -108,19 +110,22 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, @abstractmethod def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, - input: Union[TInput, None] = None, - instance_id: Union[str, None] = None) -> Task[TOutput]: + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]: """Schedule sub-orchestrator function for execution. Parameters ---------- orchestrator: Orchestrator[TInput, TOutput] A reference to the orchestrator function to call. - input: Union[TInput, None] + input: Optional[TInput] The optional JSON-serializable input to pass to the orchestrator function. - instance_id: Union[str, None] + instance_id: Optional[str] A unique ID to use for the sub-orchestration instance. If not specified, a random UUID will be used. + retry_policy: Optional[RetryPolicy] + The retry policy to use for this sub-orchestrator call. Returns ------- @@ -162,7 +167,7 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: class FailureDetails: - def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]): + def __init__(self, message: str, error_type: str, stack_trace: Optional[str]): self._message = message self._error_type = error_type self._stack_trace = stack_trace @@ -176,7 +181,7 @@ def error_type(self) -> str: return self._error_type @property - def stack_trace(self) -> Union[str, None]: + def stack_trace(self) -> Optional[str]: return self._stack_trace @@ -206,8 +211,8 @@ class OrchestrationStateError(Exception): class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" _result: T - _exception: Union[TaskFailedError, None] - _parent: Union[CompositeTask[T], None] + _exception: Optional[TaskFailedError] + _parent: Optional[CompositeTask[T]] def __init__(self) -> None: super().__init__() @@ -376,6 +381,74 @@ def task_id(self) -> int: Activity = Callable[[ActivityContext, TInput], TOutput] +class RetryPolicy: + """Represents the retry policy for an orchestration or activity function.""" + + def __init__(self, *, + first_retry_interval: timedelta, + max_number_of_attempts: int, + backoff_coefficient: Optional[float] = 1.0, + max_retry_interval: Optional[timedelta] = None, + retry_timeout: Optional[timedelta] = None): + """Creates a new RetryPolicy instance. + + Parameters + ---------- + first_retry_interval : timedelta + The retry interval to use for the first retry attempt. + max_number_of_attempts : int + The maximum number of retry attempts. + backoff_coefficient : Optional[float] + The backoff coefficient to use for calculating the next retry interval. + max_retry_interval : Optional[timedelta] + The maximum retry interval to use for any retry attempt. + retry_timeout : Optional[timedelta] + The maximum amount of time to spend retrying the operation. + """ + # validate inputs + if first_retry_interval < timedelta(seconds=0): + raise ValueError('first_retry_interval must be >= 0') + if max_number_of_attempts < 1: + raise ValueError('max_number_of_attempts must be >= 1') + if backoff_coefficient is not None and backoff_coefficient < 1: + raise ValueError('backoff_coefficient must be >= 1') + if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0): + raise ValueError('max_retry_interval must be >= 0') + if retry_timeout is not None and retry_timeout < timedelta(seconds=0): + raise ValueError('retry_timeout must be >= 0') + + self._first_retry_interval = first_retry_interval + self._max_number_of_attempts = max_number_of_attempts + self._backoff_coefficient = backoff_coefficient + self._max_retry_interval = max_retry_interval + self._retry_timeout = retry_timeout + + @property + def first_retry_interval(self) -> timedelta: + """The retry interval to use for the first retry attempt.""" + return self._first_retry_interval + + @property + def max_number_of_attempts(self) -> int: + """The maximum number of retry attempts.""" + return self._max_number_of_attempts + + @property + def backoff_coefficient(self) -> Optional[float]: + """The backoff coefficient to use for calculating the next retry interval.""" + return self._backoff_coefficient + + @property + def max_retry_interval(self) -> Optional[timedelta]: + """The maximum retry interval to use for any retry attempt.""" + return self._max_retry_interval + + @property + def retry_timeout(self) -> Optional[timedelta]: + """The maximum amount of time to spend retrying the operation.""" + return self._retry_timeout + + def get_name(fn: Callable) -> str: """Returns the name of the provided function""" name = fn.__name__ diff --git a/durabletask/worker.py b/durabletask/worker.py index fd14b28..fffa029 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -6,7 +6,8 @@ from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType -from typing import Any, Dict, Generator, List, Sequence, TypeVar, Union +from typing import (Any, Dict, Generator, List, Optional, Sequence, TypeVar, + Union) import grpc from google.protobuf import empty_pb2 @@ -47,7 +48,7 @@ def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None: self.orchestrators[name] = fn - def get_orchestrator(self, name: str) -> Union[task.Orchestrator, None]: + def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]: return self.orchestrators.get(name) def add_activity(self, fn: task.Activity) -> str: @@ -66,7 +67,7 @@ def add_named_activity(self, name: str, fn: task.Activity) -> None: self.activities[name] = fn - def get_activity(self, name: str) -> Union[task.Activity, None]: + def get_activity(self, name: str) -> Optional[task.Activity]: return self.activities.get(name) @@ -81,17 +82,16 @@ class ActivityNotRegisteredError(ValueError): class TaskHubGrpcWorker: - _response_stream: Union[grpc.Future, None] + _response_stream: Optional[grpc.Future] = None def __init__(self, *, - host_address: Union[str, None] = None, + host_address: Optional[str] = None, log_handler=None, - log_formatter: Union[logging.Formatter, None] = None): + log_formatter: Optional[logging.Formatter] = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() - self._response_stream = None self._is_running = False def __enter__(self): @@ -220,8 +220,8 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS class _RuntimeOrchestrationContext(task.OrchestrationContext): - _generator: Union[Generator[task.Task, Any, Any], None] - _previous_task: Union[task.Task, None] + _generator: Optional[Generator[task.Task, Any, Any]] + _previous_task: Optional[task.Task] def __init__(self, instance_id: str): self._generator = None @@ -233,10 +233,10 @@ def __init__(self, instance_id: str): self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id - self._completion_status: Union[pb.OrchestrationStatus, None] = None + self._completion_status: Optional[pb.OrchestrationStatus] = None self._received_events: Dict[str, List[Any]] = {} self._pending_events: Dict[str, List[task.CompletableTask]] = {} - self._new_input: Union[Any, None] = None + self._new_input: Optional[Any] = None self._save_events = False def run(self, generator: Generator[task.Task, Any, Any]): @@ -281,7 +281,7 @@ def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_en self._pending_actions.clear() # Cancel any pending actions self._result = result - result_json: Union[str, None] = None + result_json: Optional[str] = None if result is not None: result_json = result if is_result_encoded else shared.to_json(result) action = ph.new_complete_orchestration_action( @@ -314,7 +314,7 @@ def set_continued_as_new(self, new_input: Any, save_events: bool): def get_actions(self) -> List[pb.OrchestratorAction]: if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: # When continuing-as-new, we only return a single completion action. - carryover_events: Union[List[pb.HistoryEvent], None] = None + carryover_events: Optional[List[pb.HistoryEvent]] = None if self._save_events: carryover_events = [] # We need to save the current set of pending events so that they can be @@ -365,7 +365,8 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return timer_task def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *, - input: Union[TInput, None] = None) -> task.Task[TOutput]: + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: id = self.next_sequence_number() name = activity if isinstance(activity, str) else task.get_name(activity) encoded_input = shared.to_json(input) if input else None @@ -377,8 +378,9 @@ def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *, return activity_task def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *, - input: Union[TInput, None] = None, - instance_id: Union[str, None] = None) -> task.Task[TOutput]: + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: id = self.next_sequence_number() name = task.get_name(orchestrator) if instance_id is None: @@ -422,12 +424,11 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None: class _OrchestrationExecutor: - _generator: Union[task.Orchestrator, None] + _generator: Optional[task.Orchestrator] = None def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger - self._generator = None self._is_suspended = False self._suspended_events: List[pb.HistoryEvent] = [] @@ -558,6 +559,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.") return + # TODO: If there's a retry policy, we need to check if we should retry. + # Retries involve 1) scheduling a retry timer, and 2) scheduling a task to execute the activity function again. + # Only if we exhaust the retry policy do we fail the activity task (see below). activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", event.taskFailed.failureDetails) @@ -602,6 +606,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.") return + # TODO: If there's a retry policy, we need to check if we should retry. + # Retries involve 1) scheduling a retry timer, and 2) scheduling a task to execute the activity function again. + # Only if the retry policy is exhausted should we fail the sub-orchestration task (see below). sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", failedEvent.failureDetails) @@ -612,7 +619,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id} Event raised: {event_name}") task_list = ctx._pending_events.get(event_name, None) - decoded_result: Union[Any, None] = None + decoded_result: Optional[Any] = None if task_list: event_task = task_list.pop(0) if not ph.is_empty(event.eventRaised.input): @@ -661,7 +668,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger - def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Union[str, None]) -> Union[str, None]: + def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Optional[str]) -> Optional[str]: """Executes an activity function and returns the serialized result, if any.""" self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") fn = self._registry.get_activity(name) diff --git a/tests/test_orchestration_e2e.py b/tests/test_orchestration_e2e.py index edafded..1447918 100644 --- a/tests/test_orchestration_e2e.py +++ b/tests/test_orchestration_e2e.py @@ -268,3 +268,39 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert state.serialized_output == json.dumps(all_results) assert state.serialized_input == json.dumps(4) assert all_results == [1, 2, 3, 4, 5] + + +def test_retry_policies(): + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(milliseconds=1), + max_number_of_attempts=3, + ) + + def parent_orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator(child_orchestrator, retry_policy=retry_policy) + + def child_orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(throw_activity, retry_policy=retry_policy) + + def throw_activity(ctx: task.ActivityContext, _): + raise RuntimeError("Kah-BOOOOM!!!") + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(parent_orchestrator) + w.add_orchestrator(child_orchestrator) + w.add_activity(throw_activity) + w.start() + + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(parent_orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "RuntimeError" + assert state.failure_details.message == "Kah-BOOOOM!!!" + assert state.failure_details.stack_trace is not None + assert state.failure_details.stack_trace.startswith("RuntimeError: Kah-BOOOOM!!!\n") + + # TODO: Verify that the throw_activity was called 9 times and the child_orchestrator was called 3 times diff --git a/tests/test_orchestration_executor.py b/tests/test_orchestration_executor.py index 83fee74..9f716b2 100644 --- a/tests/test_orchestration_executor.py +++ b/tests/test_orchestration_executor.py @@ -228,6 +228,98 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): assert user_code_statement in complete_action.failureDetails.stackTrace.value +def test_activity_retry_policies(): + """Tests the retry policy logic for activity tasks""" + + def dummy_activity(ctx, _): + raise ValueError("Kah-BOOOOM!!!") + + def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): + result = yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=30)), + input=orchestrator_input) + return result + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + # Simulate the task failing for the first time and confirm that a timer is scheduled for 1 second in the future + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + current_timestamp = datetime.utcnow() + expected_fire_at = current_timestamp + timedelta(seconds=1) + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(1, Exception("Kah-BOOOOM!!!"))] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + assert actions[0].createTimer.fireAt.ToDatetime() == expected_fire_at + assert actions[0].id == 2 + + # Simulate the timer firing at the expected time and confirm that another activity task is scheduled + current_timestamp = expected_fire_at + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(2, current_timestamp)] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("taskScheduled") + assert actions[0].id == 3 + + # Simulate the task failing for the second time and confirm that a timer is scheduled for 2 seconds in the future + old_events = old_events + new_events + current_timestamp = current_timestamp + timedelta(seconds=1) + expected_fire_at = current_timestamp + timedelta(seconds=2) + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(3, Exception("Kah-BOOOOM!!!"))] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + assert actions[0].createTimer.fireAt.ToDatetime() == expected_fire_at + assert actions[0].id == 4 + + # Simulate the timer firing at the expected time and confirm that another activity task is scheduled + current_timestamp = expected_fire_at + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_timer_fired_event(4, current_timestamp)] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("taskScheduled") + assert actions[0].id == 5 + + # Simulate the task failing for a third time and confirm that a timer is scheduled for 4 seconds in the future + expected_fire_at = current_timestamp + timedelta(seconds=4) + old_events = old_events + new_events + new_events = [ + helpers.new_orchestrator_started_event(current_timestamp), + helpers.new_task_failed_event(5, Exception("Kah-BOOOOM!!!"))] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + assert actions[0].createTimer.fireAt.ToDatetime() == expected_fire_at + assert actions[0].id == 6 + + # TODO: Keep going, and confirm the behavior of max_retry_interval and retry_timeout + + def test_nondeterminism_expected_timer(): """Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead""" def dummy_activity(ctx, _):