Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 36a1b43

Browse files
SameerMesiah97Ratasa143
authored and committed
Add best-effort cleanup to EmrCreateJobFlowOperator on post-creation failure (apache#61010)
1 parent f371ced commit 36a1b43

2 files changed

Lines changed: 151 additions & 45 deletions

File tree

providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py

Lines changed: 75 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from typing import TYPE_CHECKING, Any
2525
from uuid import uuid4
2626

27+
from botocore.exceptions import WaiterError
28+
2729
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
2830
from airflow.providers.amazon.aws.links.emr import (
2931
EmrClusterLink,
@@ -665,6 +667,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
665667
:param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
666668
This implies waiting for completion. This mode requires aiobotocore module to be installed.
667669
(default: False)
670+
:param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
671+
when a failure occurs after the job flow has been created. Cleanup failures do not mask the
672+
original exception. (default: True)
668673
"""
669674

670675
aws_hook_class = EmrHook
@@ -691,6 +696,7 @@ def __init__(
691696
waiter_max_attempts: int | None = None,
692697
waiter_delay: int | None = None,
693698
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
699+
terminate_job_flow_on_failure: bool = True,
694700
**kwargs: Any,
695701
):
696702
super().__init__(**kwargs)
@@ -699,6 +705,7 @@ def __init__(
699705
self.waiter_max_attempts = waiter_max_attempts or 60
700706
self.waiter_delay = waiter_delay or 60
701707
self.deferrable = deferrable
708+
self.terminate_job_flow_on_failure = terminate_job_flow_on_failure
702709
self.wait_policy = wait_policy
703710

704711
# Backwards-compatible default: if the user requested waiting for
@@ -746,58 +753,81 @@ def execute(self, context: Context) -> str | None:
746753

747754
self._job_flow_id = response["JobFlowId"]
748755
self.log.info("Job flow with id %s created", self._job_flow_id)
749-
EmrClusterLink.persist(
750-
context=context,
751-
operator=self,
752-
region_name=self.hook.conn_region_name,
753-
aws_partition=self.hook.conn_partition,
754-
job_flow_id=self._job_flow_id,
755-
)
756-
if self._job_flow_id:
757-
EmrLogsLink.persist(
756+
try:
757+
EmrClusterLink.persist(
758758
context=context,
759759
operator=self,
760760
region_name=self.hook.conn_region_name,
761761
aws_partition=self.hook.conn_partition,
762762
job_flow_id=self._job_flow_id,
763-
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
764763
)
765-
if self.wait_for_completion:
766-
# Determine which waiter to use. Prefer explicit wait_policy when provided,
767-
# otherwise default to WAIT_FOR_COMPLETION.
768-
wp = self.wait_policy
769-
if wp is not None:
770-
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
771-
else:
772-
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
773-
774-
if self.deferrable:
775-
# Pass the selected waiter_name to the trigger so deferrable mode waits
776-
# according to the requested policy as well.
777-
self.defer(
778-
trigger=EmrCreateJobFlowTrigger(
779-
job_flow_id=self._job_flow_id,
780-
aws_conn_id=self.aws_conn_id,
781-
waiter_delay=self.waiter_delay,
782-
waiter_max_attempts=self.waiter_max_attempts,
783-
waiter_name=waiter_name,
784-
),
785-
method_name="execute_complete",
786-
# timeout is set to ensure that if a trigger dies, the timeout does not restart
787-
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
788-
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
789-
)
790-
else:
791-
self.hook.get_waiter(waiter_name).wait(
792-
ClusterId=self._job_flow_id,
793-
WaiterConfig=prune_dict(
794-
{
795-
"Delay": self.waiter_delay,
796-
"MaxAttempts": self.waiter_max_attempts,
797-
}
798-
),
764+
if self._job_flow_id:
765+
EmrLogsLink.persist(
766+
context=context,
767+
operator=self,
768+
region_name=self.hook.conn_region_name,
769+
aws_partition=self.hook.conn_partition,
770+
job_flow_id=self._job_flow_id,
771+
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
799772
)
800-
return self._job_flow_id
773+
if self.wait_for_completion:
774+
# Determine which waiter to use. Prefer explicit wait_policy when provided,
775+
# otherwise default to WAIT_FOR_COMPLETION.
776+
wp = self.wait_policy
777+
if wp is not None:
778+
waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
779+
else:
780+
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
781+
782+
if self.deferrable:
783+
# Pass the selected waiter_name to the trigger so deferrable mode waits
784+
# according to the requested policy as well.
785+
self.defer(
786+
trigger=EmrCreateJobFlowTrigger(
787+
job_flow_id=self._job_flow_id,
788+
aws_conn_id=self.aws_conn_id,
789+
waiter_delay=self.waiter_delay,
790+
waiter_max_attempts=self.waiter_max_attempts,
791+
waiter_name=waiter_name,
792+
),
793+
method_name="execute_complete",
794+
# timeout is set to ensure that if a trigger dies, the timeout does not restart
795+
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
796+
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
797+
)
798+
else:
799+
self.hook.get_waiter(waiter_name).wait(
800+
ClusterId=self._job_flow_id,
801+
WaiterConfig=prune_dict(
802+
{
803+
"Delay": self.waiter_delay,
804+
"MaxAttempts": self.waiter_max_attempts,
805+
}
806+
),
807+
)
808+
return self._job_flow_id
809+
810+
# Best-effort cleanup when waiting on the job flow fails. Only WaiterError is handled here, so any other post-creation error propagates without a termination attempt.
811+
except WaiterError:
812+
if self._job_flow_id:
813+
if self.terminate_job_flow_on_failure:
814+
self.log.warning(
815+
"Task failed after creating EMR job flow %s.",
816+
self._job_flow_id,
817+
)
818+
try:
819+
self.log.info(
820+
"Attempting termination of EMR job flow %s.",
821+
self._job_flow_id,
822+
)
823+
824+
self.hook.conn.terminate_job_flows(JobFlowIds=[self._job_flow_id])
825+
except Exception:
826+
self.log.exception(
827+
"Failed to terminate EMR job flow %s after task failure.",
828+
self._job_flow_id,
829+
)
830+
raise
801831

802832
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
803833
validated_event = validate_execute_complete_event(event)

providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from unittest.mock import MagicMock, patch
2424

2525
import pytest
26+
from botocore.exceptions import ClientError, WaiterError
2627
from botocore.waiter import Waiter
2728
from jinja2 import StrictUndefined
2829

@@ -231,6 +232,7 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):
231232

232233
self.operator.deferrable = True
233234
self.operator.wait_for_completion = True
235+
234236
with pytest.raises(TaskDeferred) as exc:
235237
self.operator.execute(self.mock_context)
236238

@@ -281,3 +283,77 @@ def test_specify_only_wait_policy(self):
281283
)
282284
assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
283285
assert op.wait_for_completion is True
286+
287+
def test_cleanup_on_post_create_failure(self, mocked_hook_client):
    """
    A waiter failure after a successful ``run_job_flow`` triggers cleanup.

    The mocked client reports a successfully created job flow, then the
    waiter raises ``WaiterError``. The operator is expected to re-raise that
    very exception object and to call ``terminate_job_flows`` exactly once
    for the created job flow.
    """
    mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

    # Enable both waiting and the cleanup-on-failure behaviour under test.
    self.operator.terminate_job_flow_on_failure = True
    self.operator.wait_for_completion = True

    expected_error = WaiterError(
        name="ClusterRunning",
        reason="You are not authorized to perform this operation",
        last_response={},
    )

    with (
        patch.object(self.operator.hook, "get_waiter") as waiter_factory,
        patch.object(self.operator.hook.conn, "terminate_job_flows") as terminate_call,
    ):
        # Make the waiter fail once the job flow has already been created.
        waiter_factory.return_value.wait.side_effect = expected_error

        with pytest.raises(WaiterError) as raised:
            self.operator.execute(self.mock_context)

    # The caller must see the original exception object, not a wrapper.
    assert raised.value is expected_error

    # Termination of the created job flow must have been attempted.
    terminate_call.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])
318+
319+
def test_cleanup_failure_does_not_mask_original_exception(self, mocked_hook_client):
    """
    A failing cleanup call must not replace the original waiter error.

    The waiter raises ``WaiterError`` and the subsequent
    ``terminate_job_flows`` call raises ``ClientError``. The operator is
    expected to surface the ``WaiterError`` object unchanged while still
    having attempted termination exactly once.
    """
    mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

    # Enable both waiting and the cleanup-on-failure behaviour under test.
    self.operator.terminate_job_flow_on_failure = True
    self.operator.wait_for_completion = True

    denied_message = "You are not authorized to perform this operation"

    original_error = WaiterError(
        name="ClusterRunning",
        reason=denied_message,
        last_response={},
    )
    terminate_error = ClientError(
        error_response={
            "Error": {
                "Code": "UnauthorizedOperation",
                "Message": denied_message,
            }
        },
        operation_name="TerminateJobFlows",
    )

    with (
        patch.object(self.operator.hook, "get_waiter") as waiter_factory,
        patch.object(self.operator.hook.conn, "terminate_job_flows") as terminate_call,
    ):
        # First the waiter fails, then the cleanup call fails as well.
        waiter_factory.return_value.wait.side_effect = original_error
        terminate_call.side_effect = terminate_error

        with pytest.raises(WaiterError) as raised:
            self.operator.execute(self.mock_context)

    # The waiter error, not the cleanup error, is what propagates.
    assert raised.value is original_error

    # Cleanup was still attempted even though it failed.
    terminate_call.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])

0 commit comments

Comments
 (0)