Skip to content

Commit 2f2cdca

Browse files
committed
RHOAIENG-8098 - ClusterConfiguration should support tolerations
1 parent 6b0a3cc commit 2f2cdca

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
This sub-module exists primarily to be used internally by the Cluster object
1717
(in the cluster sub-module) for RayCluster/AppWrapper generation.
1818
"""
19-
from typing import Union, Tuple, Dict
19+
from typing import List, Union, Tuple, Dict
2020
from ...common import _kube_api_error_handling
2121
from ...common.kubernetes_cluster import get_api_client, config_check
2222
from kubernetes.client.exceptions import ApiException
@@ -40,6 +40,7 @@
4040
V1PodTemplateSpec,
4141
V1PodSpec,
4242
V1LocalObjectReference,
43+
V1Toleration
4344
)
4445

4546
import yaml
@@ -139,7 +140,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139140
"resources": head_resources,
140141
},
141142
"template": {
142-
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
143+
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)],
144+
cluster.config.head_tolerations)
143145
},
144146
},
145147
"workerGroupSpecs": [
@@ -154,7 +156,8 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154156
"resources": worker_resources,
155157
},
156158
"template": V1PodTemplateSpec(
157-
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
159+
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)],
160+
cluster.config.tolerations)
158161
),
159162
}
160163
],
@@ -243,14 +246,22 @@ def update_image(image) -> str:
243246
return image
244247

245248

246-
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
249+
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers, tolerations):
247250
"""
248251
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249252
"""
250-
pod_spec = V1PodSpec(
251-
containers=containers,
252-
volumes=VOLUMES,
253-
)
253+
if tolerations is None:
254+
pod_spec = V1PodSpec(
255+
containers=containers,
256+
volumes=VOLUMES
257+
)
258+
else:
259+
pod_spec = V1PodSpec(
260+
containers=containers,
261+
volumes=VOLUMES,
262+
tolerations=tolerations
263+
)
264+
254265
if cluster.config.image_pull_secrets != []:
255266
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
256267

src/codeflare_sdk/ray/cluster/config.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Toleration
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -57,6 +58,8 @@ class ClusterConfiguration:
5758
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
5859
head_extended_resource_requests:
5960
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
head_tolerations:
62+
List of tolerations for head nodes.
6063
min_cpus:
6164
The minimum number of CPUs to allocate to each worker.
6265
max_cpus:
@@ -69,6 +72,8 @@ class ClusterConfiguration:
6972
The maximum amount of memory to allocate to each worker.
7073
num_gpus:
7174
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75+
tolerations:
76+
List of tolerations for worker nodes.
7277
appwrapper:
7378
A boolean indicating whether to use an AppWrapper.
7479
envs:
@@ -105,6 +110,7 @@ class ClusterConfiguration:
105110
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
106111
default_factory=dict
107112
)
113+
head_tolerations: Optional[List[V1Toleration]] = None
108114
worker_cpu_requests: Union[int, str] = 1
109115
worker_cpu_limits: Union[int, str] = 1
110116
min_cpus: Optional[Union[int, str]] = None # Deprecating
@@ -115,6 +121,7 @@ class ClusterConfiguration:
115121
min_memory: Optional[Union[int, str]] = None # Deprecating
116122
max_memory: Optional[Union[int, str]] = None # Deprecating
117123
num_gpus: Optional[int] = None # Deprecating
124+
tolerations: Optional[List[V1Toleration]] = None
118125
appwrapper: bool = False
119126
envs: Dict[str, str] = field(default_factory=dict)
120127
image: str = ""
@@ -265,7 +272,10 @@ def check_type(value, expected_type):
265272
if origin_type is Union:
266273
return any(check_type(value, union_type) for union_type in args)
267274
if origin_type is list:
268-
return all(check_type(elem, args[0]) for elem in value)
275+
if value is not None:
276+
return all(check_type(elem, args[0]) for elem in value)
277+
else:
278+
return True
269279
if origin_type is dict:
270280
return all(
271281
check_type(k, args[0]) and check_type(v, args[1])

0 commit comments

Comments
 (0)