2424from typing import TYPE_CHECKING , Any
2525from uuid import uuid4
2626
27+ from botocore .exceptions import WaiterError
28+
2729from airflow .providers .amazon .aws .hooks .emr import EmrContainerHook , EmrHook , EmrServerlessHook
2830from airflow .providers .amazon .aws .links .emr import (
2931 EmrClusterLink ,
@@ -665,6 +667,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
665667 :param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
666668 This implies waiting for completion. This mode requires aiobotocore module to be installed.
667669 (default: False)
670+ :param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
671+ when a failure occurs after the job flow has been created. Cleanup failures do not mask the
672+ original exception. (default: True)
668673 """
669674
670675 aws_hook_class = EmrHook
@@ -691,6 +696,7 @@ def __init__(
691696 waiter_max_attempts : int | None = None ,
692697 waiter_delay : int | None = None ,
693698 deferrable : bool = conf .getboolean ("operators" , "default_deferrable" , fallback = False ),
699+ terminate_job_flow_on_failure : bool = True ,
694700 ** kwargs : Any ,
695701 ):
696702 super ().__init__ (** kwargs )
@@ -699,6 +705,7 @@ def __init__(
699705 self .waiter_max_attempts = waiter_max_attempts or 60
700706 self .waiter_delay = waiter_delay or 60
701707 self .deferrable = deferrable
708+ self .terminate_job_flow_on_failure = terminate_job_flow_on_failure
702709 self .wait_policy = wait_policy
703710
704711 # Backwards-compatible default: if the user requested waiting for
@@ -746,58 +753,81 @@ def execute(self, context: Context) -> str | None:
746753
747754 self ._job_flow_id = response ["JobFlowId" ]
748755 self .log .info ("Job flow with id %s created" , self ._job_flow_id )
749- EmrClusterLink .persist (
750- context = context ,
751- operator = self ,
752- region_name = self .hook .conn_region_name ,
753- aws_partition = self .hook .conn_partition ,
754- job_flow_id = self ._job_flow_id ,
755- )
756- if self ._job_flow_id :
757- EmrLogsLink .persist (
756+ try :
757+ EmrClusterLink .persist (
758758 context = context ,
759759 operator = self ,
760760 region_name = self .hook .conn_region_name ,
761761 aws_partition = self .hook .conn_partition ,
762762 job_flow_id = self ._job_flow_id ,
763- log_uri = get_log_uri (emr_client = self .hook .conn , job_flow_id = self ._job_flow_id ),
764763 )
765- if self .wait_for_completion :
766- # Determine which waiter to use. Prefer explicit wait_policy when provided,
767- # otherwise default to WAIT_FOR_COMPLETION.
768- wp = self .wait_policy
769- if wp is not None :
770- waiter_name = WAITER_POLICY_NAME_MAPPING [wp ]
771- else :
772- waiter_name = WAITER_POLICY_NAME_MAPPING [WaitPolicy .WAIT_FOR_COMPLETION ]
773-
774- if self .deferrable :
775- # Pass the selected waiter_name to the trigger so deferrable mode waits
776- # according to the requested policy as well.
777- self .defer (
778- trigger = EmrCreateJobFlowTrigger (
779- job_flow_id = self ._job_flow_id ,
780- aws_conn_id = self .aws_conn_id ,
781- waiter_delay = self .waiter_delay ,
782- waiter_max_attempts = self .waiter_max_attempts ,
783- waiter_name = waiter_name ,
784- ),
785- method_name = "execute_complete" ,
786- # timeout is set to ensure that if a trigger dies, the timeout does not restart
787- # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
788- timeout = timedelta (seconds = self .waiter_max_attempts * self .waiter_delay + 60 ),
789- )
790- else :
791- self .hook .get_waiter (waiter_name ).wait (
792- ClusterId = self ._job_flow_id ,
793- WaiterConfig = prune_dict (
794- {
795- "Delay" : self .waiter_delay ,
796- "MaxAttempts" : self .waiter_max_attempts ,
797- }
798- ),
764+ if self ._job_flow_id :
765+ EmrLogsLink .persist (
766+ context = context ,
767+ operator = self ,
768+ region_name = self .hook .conn_region_name ,
769+ aws_partition = self .hook .conn_partition ,
770+ job_flow_id = self ._job_flow_id ,
771+ log_uri = get_log_uri (emr_client = self .hook .conn , job_flow_id = self ._job_flow_id ),
799772 )
800- return self ._job_flow_id
773+ if self .wait_for_completion :
774+ # Determine which waiter to use. Prefer explicit wait_policy when provided,
775+ # otherwise default to WAIT_FOR_COMPLETION.
776+ wp = self .wait_policy
777+ if wp is not None :
778+ waiter_name = WAITER_POLICY_NAME_MAPPING [wp ]
779+ else :
780+ waiter_name = WAITER_POLICY_NAME_MAPPING [WaitPolicy .WAIT_FOR_COMPLETION ]
781+
782+ if self .deferrable :
783+ # Pass the selected waiter_name to the trigger so deferrable mode waits
784+ # according to the requested policy as well.
785+ self .defer (
786+ trigger = EmrCreateJobFlowTrigger (
787+ job_flow_id = self ._job_flow_id ,
788+ aws_conn_id = self .aws_conn_id ,
789+ waiter_delay = self .waiter_delay ,
790+ waiter_max_attempts = self .waiter_max_attempts ,
791+ waiter_name = waiter_name ,
792+ ),
793+ method_name = "execute_complete" ,
794+ # timeout is set to ensure that if a trigger dies, the timeout does not restart
795+ # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
796+ timeout = timedelta (seconds = self .waiter_max_attempts * self .waiter_delay + 60 ),
797+ )
798+ else :
799+ self .hook .get_waiter (waiter_name ).wait (
800+ ClusterId = self ._job_flow_id ,
801+ WaiterConfig = prune_dict (
802+ {
803+ "Delay" : self .waiter_delay ,
804+ "MaxAttempts" : self .waiter_max_attempts ,
805+ }
806+ ),
807+ )
808+ return self ._job_flow_id
809+
810+ # Best-effort cleanup when waiting on the created job flow fails. Only WaiterError is handled here; other exception types raised inside the try block propagate without termination.
811+ except WaiterError :
812+ if self ._job_flow_id :
813+ if self .terminate_job_flow_on_failure :
814+ self .log .warning (
815+ "Task failed after creating EMR job flow %s." ,
816+ self ._job_flow_id ,
817+ )
818+ try :
819+ self .log .info (
820+ "Attempting termination of EMR job flow %s." ,
821+ self ._job_flow_id ,
822+ )
823+
824+ self .hook .conn .terminate_job_flows (JobFlowIds = [self ._job_flow_id ])
825+ except Exception :
826+ self .log .exception (
827+ "Failed to terminate EMR job flow %s after task failure." ,
828+ self ._job_flow_id ,
829+ )
830+ raise
801831
802832 def execute_complete (self , context : Context , event : dict [str , Any ] | None = None ) -> str :
803833 validated_event = validate_execute_complete_event (event )
0 commit comments