Skip to content

Commit 5d1714d

Browse files
committed
test ray job poc
1 parent 353606e commit 5d1714d

File tree

2 files changed

+248
-6
lines changed

2 files changed

+248
-6
lines changed

src/codeflare_sdk/ray/cluster/cluster.py

Lines changed: 245 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@
2020

2121
from time import sleep
2222
from typing import List, Optional, Tuple, Dict
23+
import copy
2324

24-
from ray.job_submission import JobSubmissionClient
25+
from ray.job_submission import JobSubmissionClient, JobStatus
26+
import time
27+
import uuid
28+
import warnings
29+
30+
from ..job.job import RayJobSpec
2531

2632
from ...common.kubernetes_cluster.auth import (
2733
config_check,
@@ -57,7 +63,6 @@
5763
from kubernetes.client.rest import ApiException
5864

5965
from kubernetes.client.rest import ApiException
60-
import warnings
6166

6267
CF_SDK_FIELD_MANAGER = "codeflare-sdk"
6368

@@ -604,6 +609,238 @@ def _component_resources_down(
604609
yamls = yaml.safe_load_all(self.resource_yaml)
605610
_delete_resources(yamls, namespace, api_instance, cluster_name)
606611

612+
@staticmethod
def run_job_with_managed_cluster(
    cluster_config: ClusterConfiguration,
    job_config: RayJobSpec,
    job_cr_name: Optional[str] = None,
    submission_mode: str = "K8sJobMode",
    shutdown_after_job_finishes: bool = True,
    ttl_seconds_after_finished: Optional[int] = None,
    suspend_rayjob_creation: bool = False,
    wait_for_completion: bool = True,
    job_timeout_seconds: Optional[int] = 3600,
    job_polling_interval_seconds: int = 10,
) -> Dict[str, Optional[str]]:
    """
    Create a RayJob custom resource that embeds a RayCluster spec and optionally wait for it.

    This method will:
    1. Build a RayCluster resource from a deep copy of 'cluster_config' (with
       'appwrapper' forced to False) and take its 'spec'.
    2. Construct a RayJob custom resource from 'job_config' with that spec embedded
       as 'rayClusterSpec'.
    3. Create the RayJob resource in the namespace given by 'cluster_config.namespace'.
    4. Optionally poll the RayJob status until 'jobStatus' is SUCCEEDED, FAILED or
       STOPPED, or until the timeout elapses.

    The RayCluster itself is not created or deleted here; that is expected to be done
    by the KubeRay operator based on the RayJob (e.g. 'shutdownAfterJobFinishes').

    Args:
        cluster_config: Configuration for the Ray cluster to embed in the RayJob. It is
            deep-copied, so the caller's object is not modified.
        job_config: RayJobSpec with job details. 'entrypoint' is required; 'runtime_env',
            'submission_id', 'metadata' and 'entrypoint_resources' are copied into the
            RayJob spec when set.
        job_cr_name: Name for the RayJob custom resource. If None (or empty), a name of
            the form 'rayjob-<10 hex chars>' is generated.
        submission_mode: Value written to the RayJob 'submissionMode' field. It is passed
            through without validation; consult the KubeRay RayJob documentation for the
            modes supported by the installed operator.
        shutdown_after_job_finishes: Value written to 'shutdownAfterJobFinishes'.
        ttl_seconds_after_finished: If not None, written to 'ttlSecondsAfterFinished'.
        suspend_rayjob_creation: If True, the RayJob is created with 'suspend: true'.
        wait_for_completion: If True, poll until the job reaches a terminal status.
        job_timeout_seconds: Maximum time to wait when 'wait_for_completion' is True.
            None or 0 disables the timeout.
        job_polling_interval_seconds: Seconds to sleep between status polls.

    Returns:
        A dictionary with the keys 'job_cr_name', 'job_submission_id', 'job_status',
        'dashboard_url' and 'ray_cluster_name'. Values not yet reported by the RayJob
        status are None (or a placeholder status string).

    Raises:
        TimeoutError: If the job does not reach a terminal status within the timeout,
            including the case where the RayJob resource can never be found.
        ApiException: For Kubernetes API errors (other than 404 while polling).
        ValueError: If the entrypoint is missing, the runtime_env cannot be dumped to
            YAML, or no RayCluster spec could be generated.
    """
    config_check()
    k8s_co_api = k8s_client.CustomObjectsApi(get_api_client())
    namespace = cluster_config.namespace

    if not job_config.entrypoint:
        raise ValueError("job_config.entrypoint must be specified.")

    # These RayJobSpec fields are not copied into the RayJob CR built below, so tell
    # the caller rather than silently dropping them.
    unused_entrypoint_fields = (
        job_config.entrypoint_num_cpus,
        job_config.entrypoint_num_gpus,
        job_config.entrypoint_memory,
    )
    if any(value is not None for value in unused_entrypoint_fields):
        warnings.warn(
            "RayJobSpec fields 'entrypoint_num_cpus', 'entrypoint_num_gpus', 'entrypoint_memory' "
            "are not directly used when creating a RayJob CR. They are primarily for the Ray Job Submission Client. "
            "Resource requests for the job driver pod should be configured in the RayCluster head node spec via ClusterConfiguration.",
            UserWarning,
        )

    # Generate rayClusterSpec from ClusterConfiguration. Work on a copy so that forcing
    # appwrapper=False does not leak back into the caller's configuration.
    temp_config_for_spec = copy.deepcopy(cluster_config)
    temp_config_for_spec.appwrapper = False

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        dummy_cluster_for_spec = Cluster(temp_config_for_spec)

    ray_cluster_cr_dict = dummy_cluster_for_spec.resource_yaml
    if not isinstance(ray_cluster_cr_dict, dict) or "spec" not in ray_cluster_cr_dict:
        raise ValueError(
            "Failed to generate RayCluster CR dictionary from ClusterConfiguration. "
            f"Got: {type(ray_cluster_cr_dict)}"
        )
    ray_cluster_spec = ray_cluster_cr_dict["spec"]

    # Prepare RayJob CR
    actual_job_cr_name = job_cr_name or f"rayjob-{uuid.uuid4().hex[:10]}"

    runtime_env_yaml_str = ""
    if job_config.runtime_env:
        try:
            runtime_env_yaml_str = yaml.dump(job_config.runtime_env)
        except yaml.YAMLError as e:
            raise ValueError(
                f"Invalid job_config.runtime_env, failed to dump to YAML: {e}"
            ) from e

    ray_job_cr_spec = {
        "entrypoint": job_config.entrypoint,
        "shutdownAfterJobFinishes": shutdown_after_job_finishes,
        "rayClusterSpec": ray_cluster_spec,
        "submissionMode": submission_mode,
    }

    if runtime_env_yaml_str:
        ray_job_cr_spec["runtimeEnvYAML"] = runtime_env_yaml_str
    if job_config.submission_id:
        ray_job_cr_spec["jobId"] = job_config.submission_id
    if job_config.metadata:
        ray_job_cr_spec["metadata"] = job_config.metadata
    if ttl_seconds_after_finished is not None:
        ray_job_cr_spec["ttlSecondsAfterFinished"] = ttl_seconds_after_finished
    if suspend_rayjob_creation:
        ray_job_cr_spec["suspend"] = True
    if job_config.entrypoint_resources:
        ray_job_cr_spec["entrypointResources"] = job_config.entrypoint_resources

    ray_job_cr = {
        "apiVersion": "ray.io/v1",
        "kind": "RayJob",
        "metadata": {
            "name": actual_job_cr_name,
            "namespace": namespace,
        },
        "spec": ray_job_cr_spec,
    }

    def _read_status() -> dict:
        """Fetch the RayJob and return its 'status' mapping (empty when unset)."""
        ray_job = k8s_co_api.get_namespaced_custom_object_status(
            group="ray.io",
            version="v1",
            namespace=namespace,
            plural="rayjobs",
            name=actual_job_cr_name,
        )
        # 'or {}' also covers a 'status' key that is present but null.
        return ray_job.get("status") or {}

    returned_job_submission_id = None
    final_job_status = "UNKNOWN"
    dashboard_url = None
    ray_cluster_name_actual = None

    try:
        print(f"Submitting RayJob '{actual_job_cr_name}' to namespace '{namespace}'...")
        k8s_co_api.create_namespaced_custom_object(
            group="ray.io",
            version="v1",
            namespace=namespace,
            plural="rayjobs",
            body=ray_job_cr,
        )
        print(f"RayJob '{actual_job_cr_name}' created successfully.")

        if wait_for_completion:
            print(f"Waiting for RayJob '{actual_job_cr_name}' to complete...")
            start_time = time.time()

            def _timed_out() -> bool:
                """True once the configured timeout (if any) has elapsed."""
                return bool(job_timeout_seconds) and (
                    time.time() - start_time
                ) > job_timeout_seconds

            # NOTE(review): only 'jobStatus' ends the wait. A RayJob whose deployment
            # fails before any job status is reported is polled until the timeout —
            # confirm whether 'jobDeploymentStatus' should also be treated as terminal.
            while True:
                try:
                    status_field = _read_status()
                except ApiException as e:
                    if e.status != 404:
                        raise
                    # The timeout must apply here too; otherwise a RayJob that never
                    # shows up (or was deleted) would be polled forever.
                    if _timed_out():
                        raise TimeoutError(
                            f"RayJob '{actual_job_cr_name}' timed out after {job_timeout_seconds} seconds. Last status: {final_job_status}"
                        ) from e
                    print(f"RayJob '{actual_job_cr_name}' status not found yet, retrying...")
                    time.sleep(job_polling_interval_seconds)
                    continue

                job_deployment_status = status_field.get("jobDeploymentStatus", "UNKNOWN")
                current_job_status = status_field.get("jobStatus", "PENDING")

                # Keep the last known value when a field is missing from this poll.
                dashboard_url = status_field.get("dashboardURL", dashboard_url)
                ray_cluster_name_actual = status_field.get(
                    "rayClusterName", ray_cluster_name_actual
                )
                returned_job_submission_id = status_field.get(
                    "jobId", job_config.submission_id
                )

                final_job_status = current_job_status
                print(
                    f"RayJob '{actual_job_cr_name}' status: JobDeployment='{job_deployment_status}', Job='{current_job_status}'"
                )

                if current_job_status in ("SUCCEEDED", "FAILED", "STOPPED"):
                    break

                if _timed_out():
                    # Best-effort refresh so the error message carries the latest
                    # status; a failure here must not mask the timeout itself.
                    try:
                        final_job_status = _read_status().get(
                            "jobStatus", final_job_status
                        )
                    except Exception as refresh_error:
                        print(
                            f"Warning: Could not refresh status for RayJob '{actual_job_cr_name}' after timeout: {refresh_error}"
                        )
                    raise TimeoutError(
                        f"RayJob '{actual_job_cr_name}' timed out after {job_timeout_seconds} seconds. Last status: {final_job_status}"
                    )

                time.sleep(job_polling_interval_seconds)

            print(f"RayJob '{actual_job_cr_name}' finished with status: {final_job_status}")
        else:
            # Not waiting: take a single snapshot of whatever status exists right now.
            try:
                status_field = _read_status()
                final_job_status = status_field.get("jobStatus", "SUBMITTED")
                returned_job_submission_id = status_field.get(
                    "jobId", job_config.submission_id
                )
                dashboard_url = status_field.get("dashboardURL", dashboard_url)
                ray_cluster_name_actual = status_field.get(
                    "rayClusterName", ray_cluster_name_actual
                )
            except ApiException as e:
                if e.status == 404:
                    final_job_status = "SUBMITTED_NOT_FOUND"
                else:
                    print(f"Warning: Could not fetch initial status for RayJob '{actual_job_cr_name}': {e}")
                    final_job_status = "UNKNOWN_API_ERROR"

        return {
            "job_cr_name": actual_job_cr_name,
            "job_submission_id": returned_job_submission_id,
            "job_status": final_job_status,
            "dashboard_url": dashboard_url,
            "ray_cluster_name": ray_cluster_name_actual,
        }

    except ApiException as e:
        print(f"Kubernetes API error during RayJob '{actual_job_cr_name}' management: {e.reason} (status: {e.status})")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during managed RayJob execution for '{actual_job_cr_name}': {e}")
        raise
843+
607844

608845
def list_all_clusters(namespace: str, print_to_console: bool = True):
609846
"""
@@ -760,14 +997,19 @@ def get_cluster(
760997
head_extended_resource_requests=head_extended_resources,
761998
worker_extended_resource_requests=worker_extended_resources,
762999
)
1000+
# 1. Prepare RayClusterSpec from ClusterConfiguration
1001+
# Create a temporary config with appwrapper=False to ensure build_ray_cluster returns RayCluster YAML
1002+
temp_cluster_config_dict = cluster_config.dict(exclude_none=True) # Assuming Pydantic V1 or similar .dict() method
1003+
temp_cluster_config_dict['appwrapper'] = False
1004+
temp_cluster_config_for_spec = ClusterConfiguration(**temp_cluster_config_dict)
7631005
# Ignore the warning here for the lack of a ClusterConfiguration
7641006
with warnings.catch_warnings():
7651007
warnings.filterwarnings(
7661008
"ignore",
7671009
message="Please provide a ClusterConfiguration to initialise the Cluster object",
7681010
)
7691011
cluster = Cluster(None)
770-
cluster.config = cluster_config
1012+
cluster.config = temp_cluster_config_for_spec
7711013

7721014
# Remove auto-generated fields like creationTimestamp, uid and etc.
7731015
remove_autogenerated_fields(resource)

src/codeflare_sdk/ray/job/job.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ class RayJobSpec:
7676
class RayJob:
7777
"""RayJob Custom Resource Definition"""
7878

79-
api_version: str = "ray.io/v1"
80-
kind: str = "RayJob"
81-
8279
metadata: Dict[str, Any]
8380
"""Kubernetes metadata for the job"""
8481

8582
spec: RayJobSpec
8683
"""Job specification"""
84+
85+
api_version: str = "ray.io/v1"
86+
kind: str = "RayJob"
8787

8888
status: Optional[Dict[str, Any]] = None
8989
"""Status of the job (managed by the controller)"""

0 commit comments

Comments
 (0)