Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6747593

Browse files
Add Ray Job operators
1 parent b906080 commit 6747593

15 files changed

Lines changed: 1650 additions & 1 deletion

File tree

airflow-core/tests/unit/always/test_project_structure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
470470
"airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaBaseOperator",
471471
"airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator",
472472
"airflow.providers.google.cloud.operators.vertex_ai.ray.RayBaseOperator",
473+
"airflow.providers.google.cloud.operators.ray.RayJobBaseOperator",
473474
"airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator",
474475
"airflow.providers.google.marketing_platform.operators.search_ads._GoogleSearchAdsBaseOperator",
475476
}

docs/spelling_wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,7 @@ Jira
981981
jira
982982
jitter
983983
JobComplete
984+
JobDetails
984985
JobExists
985986
jobflow
986987
jobId
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
.. Licensed to the Apache Software Foundation (ASF) under one
2+
or more contributor license agreements. See the NOTICE file
3+
distributed with this work for additional information
4+
regarding copyright ownership. The ASF licenses this file
5+
to you under the Apache License, Version 2.0 (the
6+
"License"); you may not use this file except in compliance
7+
with the License. You may obtain a copy of the License at
8+
9+
.. http://www.apache.org/licenses/LICENSE-2.0
10+
11+
.. Unless required by applicable law or agreed to in writing,
12+
software distributed under the License is distributed on an
13+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
KIND, either express or implied. See the License for the
15+
specific language governing permissions and limitations
16+
under the License.
17+
18+
Ray Job Operators
19+
=================
20+
21+
The Ray Job operators provide a high-level interface for interacting with remote Ray clusters
22+
using the Ray Jobs API. These operators can be used with clusters running on Google Cloud Vertex AI Ray,
23+
GKE (self-managed Ray clusters) or any Ray cluster reachable through a dashboard address or Ray Client address.
24+
25+
The operators allow you to submit jobs, monitor their progress, retrieve logs,
26+
and manage job lifecycle from Airflow.
27+
28+
Submitting Ray Jobs
29+
^^^^^^^^^^^^^^^^^^^
30+
31+
The :class:`~airflow.providers.google.cloud.operators.ray.RaySubmitJobOperator`
32+
submits a job to a Ray cluster and optionally waits for completion.
33+
34+
It supports waiting for job completion with ``wait_for_job_done``
35+
and retrieving logs after completion with ``get_job_logs`` parameters.
36+
37+
.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py
38+
:language: python
39+
:dedent: 4
40+
:start-after: [START how_to_ray_submit_job]
41+
:end-before: [END how_to_ray_submit_job]
42+
43+
Stopping Ray Jobs
44+
^^^^^^^^^^^^^^^^^
45+
46+
Use :class:`~airflow.providers.google.cloud.operators.ray.RayStopJobOperator`
47+
to stop a running Ray job identified by its job ID.
48+
49+
.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py
50+
:language: python
51+
:dedent: 4
52+
:start-after: [START how_to_ray_stop_job]
53+
:end-before: [END how_to_ray_stop_job]
54+
55+
Deleting Ray Jobs
56+
^^^^^^^^^^^^^^^^^
57+
58+
Use :class:`~airflow.providers.google.cloud.operators.ray.RayDeleteJobOperator`
59+
to delete a job and its metadata after it reaches a terminal state.
60+
61+
.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py
62+
:language: python
63+
:dedent: 4
64+
:start-after: [START how_to_ray_delete_job]
65+
:end-before: [END how_to_ray_delete_job]
66+
67+
Retrieving Job Information
68+
^^^^^^^^^^^^^^^^^^^^^^^^^^
69+
70+
The :class:`~airflow.providers.google.cloud.operators.ray.RayGetJobInfoOperator`
71+
retrieves detailed information about a Ray job, including status, timestamps,
72+
entrypoint, metadata, and runtime environment.
73+
74+
.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py
75+
:language: python
76+
:dedent: 4
77+
:start-after: [START how_to_ray_get_job_info]
78+
:end-before: [END how_to_ray_get_job_info]
79+
80+
Listing Jobs
81+
^^^^^^^^^^^^
82+
83+
Use :class:`~airflow.providers.google.cloud.operators.ray.RayListJobsOperator`
84+
to list all jobs that have run on the cluster.
85+
86+
.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py
87+
:language: python
88+
:dedent: 4
89+
:start-after: [START how_to_ray_list_jobs]
90+
:end-before: [END how_to_ray_list_jobs]

providers/google/provider.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,11 @@ integrations:
462462
how-to-guide:
463463
- /docs/apache-airflow-providers-google/operators/cloud/gen_ai.rst
464464
tags: [gcp]
465+
- integration-name: Google Ray
466+
external-doc-url: https://docs.cloud.google.com/vertex-ai/docs/open-source/ray-on-vertex-ai/overview
467+
how-to-guide:
468+
- /docs/apache-airflow-providers-google/operators/cloud/ray.rst
469+
tags: [gcp]
465470

466471
operators:
467472
- integration-name: Google Ads
@@ -624,6 +629,9 @@ operators:
624629
- integration-name: Google Cloud Generative AI
625630
python-modules:
626631
- airflow.providers.google.cloud.operators.gen_ai
632+
- integration-name: Google Ray
633+
python-modules:
634+
- airflow.providers.google.cloud.operators.ray
627635

628636
sensors:
629637
- integration-name: Google BigQuery
@@ -905,7 +913,9 @@ hooks:
905913
- integration-name: Google Cloud Generative AI
906914
python-modules:
907915
- airflow.providers.google.cloud.hooks.gen_ai
908-
916+
- integration-name: Google Ray
917+
python-modules:
918+
- airflow.providers.google.cloud.hooks.ray
909919

910920
triggers:
911921
- integration-name: Google BigQuery Data Transfer Service
@@ -1254,6 +1264,7 @@ extra-links:
12541264
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink
12551265
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink
12561266
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink
1267+
- airflow.providers.google.cloud.links.ray.RayJobLink
12571268

12581269
secrets-backends:
12591270
- airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
"""This module contains a Google Cloud Ray Job hook."""
19+
20+
from __future__ import annotations
21+
22+
from typing import TYPE_CHECKING, Any
23+
from urllib.parse import urlparse
24+
25+
from ray.job_submission import JobSubmissionClient
26+
27+
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
28+
29+
if TYPE_CHECKING:
30+
from ray.dashboard.modules.job.common import JobStatus
31+
from ray.dashboard.modules.job.pydantic_models import JobDetails
32+
33+
VERTEX_RAY_DOMAIN = "aiplatform-training.googleusercontent.com"
34+
35+
36+
class RayJobHook(GoogleBaseHook):
37+
"""Hook for Jobs APIs."""
38+
39+
def _is_vertex_ray_address(self, address: str) -> bool:
40+
"""Return True if address points to Vertex Ray dashboard host."""
41+
parsed = urlparse(address if "://" in address else f"https://{address}")
42+
hostname = parsed.hostname
43+
if not hostname:
44+
return False
45+
return hostname.endswith(VERTEX_RAY_DOMAIN)
46+
47+
def get_client(self, address: str) -> JobSubmissionClient:
48+
"""
49+
Create a client for submitting and interacting with jobs on a remote cluster.
50+
51+
:param address: Either (1) the address of the Ray cluster, or (2) the HTTP address
52+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
53+
In case (1) it must be specified as an address that can be passed to
54+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
55+
or "auto", or "localhost:<port>".
56+
"""
57+
if self._is_vertex_ray_address(address):
58+
return JobSubmissionClient(f"vertex_ray://{address}")
59+
return JobSubmissionClient(address=address)
60+
61+
def serialize_job_obj(self, job_obj: JobDetails) -> dict:
62+
"""Serialize JobDetails to a plain dict."""
63+
if hasattr(job_obj, "model_dump"): # Pydantic v2
64+
return job_obj.model_dump(exclude_none=True)
65+
if hasattr(job_obj, "dict"): # Pydantic v1
66+
return job_obj.dict(exclude_none=True)
67+
return dict(job_obj)
68+
69+
def submit_job(
70+
self,
71+
entrypoint: str,
72+
cluster_address: str,
73+
runtime_env: dict[str, Any] | None = None,
74+
metadata: dict[str, str] | None = None,
75+
submission_id: str | None = None,
76+
entrypoint_num_cpus: int | float | None = None,
77+
entrypoint_num_gpus: int | float | None = None,
78+
entrypoint_memory: int | None = None,
79+
entrypoint_resources: dict[str, float] | None = None,
80+
) -> str:
81+
"""
82+
Submit and execute Job on Ray cluster.
83+
84+
When a job is submitted, it runs once to completion or failure. Retries or
85+
different runs with different parameters should be handled by the
86+
submitter. Jobs are bound to the lifetime of a Ray cluster, so if the
87+
cluster goes down, all running jobs on that cluster will be terminated.
88+
89+
:param entrypoint: Required. The shell command to run for this job.
90+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
91+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
92+
In case (1) it must be specified as an address that can be passed to
93+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
94+
or "auto", or "localhost:<port>".
95+
:param submission_id: A unique ID for this job.
96+
:param runtime_env: The runtime environment to install and run this job in.
97+
:param metadata: Arbitrary data to store along with this job.
98+
:param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution
99+
of the entrypoint command, separately from any tasks or actors launched
100+
by it. Defaults to 0.
101+
:param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution
102+
of the entrypoint command, separately from any tasks or actors launched
103+
by it. Defaults to 0.
104+
:param entrypoint_memory: The quantity of memory to reserve for the
105+
execution of the entrypoint command, separately from any tasks or
106+
actors launched by it. Defaults to 0.
107+
:param entrypoint_resources: The quantity of custom resources to reserve for the
108+
execution of the entrypoint command, separately from any tasks or
109+
actors launched by it.
110+
"""
111+
job_id = self.get_client(address=cluster_address).submit_job(
112+
entrypoint=entrypoint,
113+
runtime_env=runtime_env,
114+
metadata=metadata,
115+
submission_id=submission_id,
116+
entrypoint_num_cpus=entrypoint_num_cpus,
117+
entrypoint_num_gpus=entrypoint_num_gpus,
118+
entrypoint_memory=entrypoint_memory,
119+
entrypoint_resources=entrypoint_resources,
120+
)
121+
return job_id
122+
123+
def stop_job(
124+
self,
125+
job_id: str,
126+
cluster_address: str,
127+
) -> bool:
128+
"""
129+
Stop Job on Ray cluster.
130+
131+
:param job_id: Required. The job ID or submission ID for the job to be stopped.
132+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
133+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
134+
In case (1) it must be specified as an address that can be passed to
135+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
136+
or "auto", or "localhost:<port>".
137+
:return: True if the job was stopped, otherwise False.
138+
"""
139+
return self.get_client(address=cluster_address).stop_job(job_id=job_id)
140+
141+
def delete_job(
142+
self,
143+
job_id: str,
144+
cluster_address: str,
145+
) -> bool:
146+
"""
147+
Delete Job on Ray cluster in a terminal state and all of its associated data.
148+
149+
If the job is not already in a terminal state, raises an error.
150+
This does not delete the job logs from disk.
151+
Submitting a job with the same submission ID as a previously
152+
deleted job is not supported and may lead to unexpected behavior.
153+
154+
:param job_id: Required. The job ID or submission ID for the job to be deleted.
155+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
156+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
157+
In case (1) it must be specified as an address that can be passed to
158+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
159+
or "auto", or "localhost:<port>".
160+
:return: True if the job was deleted, otherwise False.
161+
"""
162+
return self.get_client(address=cluster_address).delete_job(job_id=job_id)
163+
164+
def get_job_info(
165+
self,
166+
job_id: str,
167+
cluster_address: str,
168+
) -> JobDetails:
169+
"""
170+
Get the latest status and other information associated with a Job on Ray cluster.
171+
172+
:param job_id: Required. The job ID or submission ID for the job to be retrieved.
173+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
174+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
175+
In case (1) it must be specified as an address that can be passed to
176+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
177+
or "auto", or "localhost:<port>".
178+
:return: The JobDetails for the job.
179+
"""
180+
return self.get_client(address=cluster_address).get_job_info(job_id=job_id)
181+
182+
def list_jobs(
183+
self,
184+
cluster_address: str,
185+
) -> list[JobDetails]:
186+
"""
187+
List all jobs along with their status and other information.
188+
189+
Lists all jobs that have ever run on the cluster, including jobs that are
190+
currently running and jobs that are no longer running.
191+
192+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
193+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
194+
In case (1) it must be specified as an address that can be passed to
195+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
196+
or "auto", or "localhost:<port>".
197+
"""
198+
return self.get_client(address=cluster_address).list_jobs()
199+
200+
def get_job_status(
201+
self,
202+
job_id: str,
203+
cluster_address: str,
204+
) -> JobStatus:
205+
"""
206+
Get the most recent status of a Job on Ray cluster.
207+
208+
:param job_id: Required. The job ID or submission ID for the job to be retrieved.
209+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
210+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
211+
In case (1) it must be specified as an address that can be passed to
212+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
213+
or "auto", or "localhost:<port>".
214+
:return: The JobStatus of the job.
215+
"""
216+
return self.get_client(address=cluster_address).get_job_status(job_id=job_id)
217+
218+
def get_job_logs(
219+
self,
220+
job_id: str,
221+
cluster_address: str,
222+
) -> str:
223+
"""
224+
Get all logs produced by a Job on Ray cluster.
225+
226+
:param job_id: Required. The job ID or submission ID for the job to be retrieved.
227+
:param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address
228+
of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
229+
In case (1) it must be specified as an address that can be passed to
230+
ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
231+
or "auto", or "localhost:<port>".
232+
:return: A string containing the full logs of the job.
233+
"""
234+
return self.get_client(address=cluster_address).get_job_logs(job_id=job_id)

0 commit comments

Comments
 (0)