diff --git a/providers/alibaba/pyproject.toml b/providers/alibaba/pyproject.toml index ce8a445714a21..27e99ec6d6b49 100644 --- a/providers/alibaba/pyproject.toml +++ b/providers/alibaba/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version "oss2>=2.14.0", "alibabacloud_adb20211201>=1.0.0", "alibabacloud_tea_openapi>=0.3.7", diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py index 90003571066ba..723ed41b5a75c 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py @@ -21,8 +21,7 @@ from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import BaseOperator + from airflow.providers.common.compat.sdk import BaseOperator, TaskInstanceKey from airflow.sdk import Context diff --git a/providers/amazon/pyproject.toml b/providers/amazon/pyproject.toml index 6765fea349a43..accb12f8241a3 100644 --- a/providers/amazon/pyproject.toml +++ b/providers/amazon/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version "apache-airflow-providers-common-sql>=1.27.0", "apache-airflow-providers-http", # We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py index 8b08d395b7782..7dacb818e7d7f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from airflow.models import BaseOperator - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context diff --git a/providers/celery/pyproject.toml b/providers/celery/pyproject.toml index 4d8c6c838eed2..5b4ea492b8086 100644 --- a/providers/celery/pyproject.toml +++ b/providers/celery/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version # The Celery is known to introduce problems when upgraded to a MAJOR version. Airflow Core # Uses Celery for CeleryExecutor, and we also know that Kubernetes Python client follows SemVer # (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions). diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 17bfe167f832b..746de183162b9 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -43,8 +43,7 @@ from airflow.executors import workloads from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance -from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout +from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, TaskInstanceKey from airflow.providers.standard.operators.bash import BashOperator from airflow.utils.state import State diff --git a/providers/cncf/kubernetes/pyproject.toml b/providers/cncf/kubernetes/pyproject.toml index 5c660cac1414f..0c43340ec48c6 100644 --- a/providers/cncf/kubernetes/pyproject.toml +++ b/providers/cncf/kubernetes/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "aiofiles>=23.2.0", "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.10.1", + "apache-airflow-providers-common-compat>=1.10.1", # use next version "asgiref>=3.5.2", # TODO(potiuk): We should bump cryptography to >=46.0.0 when sqlalchemy>=2.0 is required "cryptography>=41.0.0,<46.0.0", diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index c2334b1c9a2e2..4dff47c4b1d78 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -166,7 +166,8 @@ def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey: # Compat: Look up the run_id from the TI table! from airflow.models.dagrun import DagRun - from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.settings import Session logical_date_key = get_logical_date_key() diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py index 398b1219e5104..ccb4e715abfb0 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py +++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py @@ -111,6 +111,7 @@ ) from airflow.sdk.execution_time.timeout import timeout as timeout from airflow.sdk.execution_time.xcom import XCom as XCom + from airflow.sdk.types import TaskInstanceKey as TaskInstanceKey from airflow.providers.common.compat._compat_utils import create_module_getattr @@ -185,6 +186,7 @@ # Operator Links & Task Groups # ============================================================================ "BaseOperatorLink": ("airflow.sdk", "airflow.models.baseoperatorlink"), + "TaskInstanceKey": ("airflow.sdk.types", "airflow.models.taskinstancekey"), "TaskGroup": ("airflow.sdk", "airflow.utils.task_group"), # ============================================================================ # Operator Utilities (chain, cross_downstream, etc.) diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml index d9524a0d9b68f..d9eaaa60b7455 100644 --- a/providers/databricks/pyproject.toml +++ b/providers/databricks/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version "apache-airflow-providers-common-sql>=1.27.0", "requests>=2.32.0,<3", "databricks-sql-connector>=4.0.0", diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index fa84e8d459cad..fe6240dad5cc1 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -53,7 +53,7 @@ from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, ) diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index cb1b5c747e9a6..d50acc800e78f 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -298,7 +298,7 @@ def xcom_key(self) -> str: """XCom key where the link is stored during task execution.""" return "databricks_job_run_link" - def get_link( + def get_link( # type: ignore[override] # Signature intentionally kept this way for Airflow 2.x compatibility self, operator: BaseOperator, dttm=None, @@ -374,7 +374,7 @@ class WorkflowJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin): name = "Repair All Failed Tasks" - def get_link( + def get_link( # type: ignore[override] # Signature intentionally kept this way for Airflow 2.x compatibility self, operator, dttm=None, @@ -471,7 +471,7 @@ class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, LoggingMixin): name = "Repair a single task" - def get_link( + def get_link( # type: ignore[override] # Signature intentionally kept this way for Airflow 2.x compatibility self, operator, dttm=None, diff --git a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py index 547e870d8aeb8..e1932d90df031 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py @@ -45,7 +45,7 @@ ) if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.providers.edge3.models.edge_worker import EdgeWorkerState from airflow.utils.state import TaskInstanceState diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py index ccf21de848fcd..d322790346522 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py @@ -27,8 +27,7 @@ from sqlalchemy.orm import Mapped from airflow.models.base import Base, StringID -from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.providers.common.compat.sdk import timezone +from airflow.providers.common.compat.sdk import TaskInstanceKey, timezone from airflow.providers.common.compat.sqlalchemy.orm import mapped_column from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.sqlalchemy import UtcDateTime diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py index 7c58b5fdc87ba..c25ff53a93437 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py @@ -26,7 +26,7 @@ from pydantic import BaseModel, Field from airflow.executors.workloads import ExecuteTask # noqa: TCH001 -from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.providers.edge3.models.edge_worker import EdgeWorkerState # noqa: TCH001 diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py index 064808b2b0484..e58e6874fd1a4 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py @@ -28,7 +28,7 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.configuration import conf from airflow.models.taskinstance import TaskInstance -from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.providers.edge3.models.edge_logs import EdgeLogsModel from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest from airflow.providers.edge3.worker_api.datamodels import PushLogsBody, WorkerApiDocs diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index e80d279c3cdbb..c03fafa2ba119 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version "apache-airflow-providers-common-sql>=1.27.0", "asgiref>=3.5.2", "dill>=0.2.3", diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py index d9c220bae1fc6..2d540ad40ee73 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/base.py +++ b/providers/google/src/airflow/providers/google/cloud/links/base.py @@ -24,8 +24,7 @@ from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator BASE_LINK = "https://console.cloud.google.com" diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py index 695c86a8e885d..b527e65573b50 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py @@ -29,8 +29,7 @@ from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey from airflow.providers.google.version_compat import BaseOperator diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py index 583683bb07130..5381588710daa 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -39,8 +39,7 @@ if TYPE_CHECKING: from google.protobuf.field_mask_pb2 import FieldMask - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey from airflow.providers.google.version_compat import BaseOperator BASE_LINK = "https://console.cloud.google.com" diff --git a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py index 9d2058c74557c..469374b133d3e 100644 --- a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py +++ b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py @@ -21,8 +21,7 @@ from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey from airflow.providers.google.version_compat import BaseOperator diff --git a/providers/microsoft/azure/pyproject.toml b/providers/microsoft/azure/pyproject.toml index 4a5f41110ec0f..a0ef35c5f5c50 100644 --- a/providers/microsoft/azure/pyproject.toml +++ b/providers/microsoft/azure/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version "adlfs>=2023.10.0", "azure-batch>=8.0.0", "azure-cosmos>=4.6.0", diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py index dba2bb126caaa..495eb27d6a7f6 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/data_factory.py @@ -40,7 +40,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/powerbi.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/powerbi.py index 8a2a7292d7b53..ea78e3cab50c4 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/powerbi.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/powerbi.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from msgraph_core import APIVersion - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py index 9cb51556cee39..5bffb6dc0d8fc 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/synapse.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from azure.synapse.spark.models import SparkBatchJobOptions - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.common.compat.sdk import TaskInstanceKey from airflow.sdk import Context diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 4ed76b057af5e..ae3f978da4349 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -58,8 +58,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey class DagIsPaused(AirflowException): @@ -89,8 +88,17 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: trigger_dag_id = operator.trigger_dag_id if not AIRFLOW_V_3_0_PLUS: from airflow.models.renderedtifields import RenderedTaskInstanceFields + from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey + + core_ti_key = CoreTaskInstanceKey( + dag_id=ti_key.dag_id, + task_id=ti_key.task_id, + run_id=ti_key.run_id, + try_number=ti_key.try_number, + map_index=ti_key.map_index, + ) - if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key): + if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef] # Fetch the correct dag_run_id for the triggerED dag which is diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 18e6578830661..91e7821d304c9 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -62,8 +62,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import Context + from airflow.providers.common.compat.sdk import Context, TaskInstanceKey class ExternalDagLink(BaseOperatorLink): @@ -83,8 +82,17 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: if not AIRFLOW_V_3_0_PLUS: from airflow.models.renderedtifields import RenderedTaskInstanceFields + from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey + + core_ti_key = CoreTaskInstanceKey( + dag_id=ti_key.dag_id, + task_id=ti_key.task_id, + run_id=ti_key.run_id, + try_number=ti_key.try_number, + map_index=ti_key.map_index, + ) - if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key): + if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id) # type: ignore[no-redef] if AIRFLOW_V_3_0_PLUS: diff --git a/providers/yandex/pyproject.toml b/providers/yandex/pyproject.toml index 8786510b78519..01ee2b892e1bf 100644 --- a/providers/yandex/pyproject.toml +++ b/providers/yandex/pyproject.toml @@ -60,7 +60,7 @@ dependencies = [ "apache-airflow>=2.11.0", "yandexcloud>=0.308.0; python_version < '3.13'", "yandex-query-client>=0.1.4; python_version < '3.13'", - "apache-airflow-providers-common-compat>=1.12.0", + "apache-airflow-providers-common-compat>=1.12.0", # use next version ] [dependency-groups] diff --git a/providers/yandex/src/airflow/providers/yandex/links/yq.py b/providers/yandex/src/airflow/providers/yandex/links/yq.py index c9f5195c1d34d..015d68bb81121 100644 --- a/providers/yandex/src/airflow/providers/yandex/links/yq.py +++ b/providers/yandex/src/airflow/providers/yandex/links/yq.py @@ -21,8 +21,7 @@ from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.common.compat.sdk import BaseOperator, Context + from airflow.providers.common.compat.sdk import BaseOperator, Context, TaskInstanceKey XCOM_WEBLINK_KEY = "web_link" diff --git a/task-sdk/src/airflow/sdk/bases/operatorlink.py b/task-sdk/src/airflow/sdk/bases/operatorlink.py index 43ffe0725f5d5..ebc5a4dcd04db 100644 --- a/task-sdk/src/airflow/sdk/bases/operatorlink.py +++ b/task-sdk/src/airflow/sdk/bases/operatorlink.py @@ -23,8 +23,8 @@ import attrs if TYPE_CHECKING: - from airflow.models.taskinstancekey import TaskInstanceKey from airflow.sdk import BaseOperator + from airflow.sdk.types import TaskInstanceKey @attrs.define() diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 9bb5e4f480f12..237c36d36cd35 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -19,7 +19,7 @@ import uuid from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet @@ -39,6 +39,39 @@ Operator: TypeAlias = BaseOperator | MappedOperator +class TaskInstanceKey(NamedTuple): + """Key used to identify task instance.""" + + dag_id: str + task_id: str + run_id: str + try_number: int = 1 + map_index: int = -1 + + @property + def primary(self) -> tuple[str, str, str, int]: + """Return task instance primary key part of the key.""" + return self.dag_id, self.task_id, self.run_id, self.map_index + + def with_try_number(self, try_number: int) -> TaskInstanceKey: + """Return TaskInstanceKey with provided ``try_number``.""" + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index) + + @property + def key(self) -> TaskInstanceKey: + """ + For API-compatibility with TaskInstance. + + Returns self + """ + return self + + @classmethod + def from_dict(cls, dictionary): + """Create TaskInstanceKey from dictionary.""" + return cls(**dictionary) + + class DagRunProtocol(Protocol): """Minimal interface for a Dag run available during the execution."""