Skip to content

Commit f3cb22e

Browse files
committed
RHOAIENG-8098 - ClusterConfiguration should support tolerations
1 parent 6b0a3cc commit f3cb22e

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
This sub-module exists primarily to be used internally by the Cluster object
1717
(in the cluster sub-module) for RayCluster/AppWrapper generation.
1818
"""
19-
from typing import Union, Tuple, Dict
19+
from typing import List, Union, Tuple, Dict
2020
from ...common import _kube_api_error_handling
2121
from ...common.kubernetes_cluster import get_api_client, config_check
2222
from kubernetes.client.exceptions import ApiException
@@ -40,6 +40,7 @@
4040
V1PodTemplateSpec,
4141
V1PodSpec,
4242
V1LocalObjectReference,
43+
V1Toleration,
4344
)
4445

4546
import yaml
@@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139140
"resources": head_resources,
140141
},
141142
"template": {
142-
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
143+
"spec": get_pod_spec(
144+
cluster,
145+
[get_head_container_spec(cluster)],
146+
cluster.config.head_tolerations,
147+
)
143148
},
144149
},
145150
"workerGroupSpecs": [
@@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154159
"resources": worker_resources,
155160
},
156161
"template": V1PodTemplateSpec(
157-
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
162+
spec=get_pod_spec(
163+
cluster,
164+
[get_worker_container_spec(cluster)],
165+
cluster.config.tolerations,
166+
)
158167
),
159168
}
160169
],
@@ -243,14 +252,17 @@ def update_image(image) -> str:
243252
return image
244253

245254

246-
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
255+
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers, tolerations):
247256
"""
248257
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249258
"""
250-
pod_spec = V1PodSpec(
251-
containers=containers,
252-
volumes=VOLUMES,
253-
)
259+
if tolerations is None:
260+
pod_spec = V1PodSpec(containers=containers, volumes=VOLUMES)
261+
else:
262+
pod_spec = V1PodSpec(
263+
containers=containers, volumes=VOLUMES, tolerations=tolerations
264+
)
265+
254266
if cluster.config.image_pull_secrets != []:
255267
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
256268

src/codeflare_sdk/ray/cluster/config.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Toleration
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -57,6 +58,8 @@ class ClusterConfiguration:
5758
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
5859
head_extended_resource_requests:
5960
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
head_tolerations:
62+
List of tolerations for head nodes.
6063
min_cpus:
6164
The minimum number of CPUs to allocate to each worker.
6265
max_cpus:
@@ -69,6 +72,8 @@ class ClusterConfiguration:
6972
The maximum amount of memory to allocate to each worker.
7073
num_gpus:
7174
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75+
tolerations:
76+
List of tolerations for worker nodes.
7277
appwrapper:
7378
A boolean indicating whether to use an AppWrapper.
7479
envs:
@@ -105,6 +110,7 @@ class ClusterConfiguration:
105110
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
106111
default_factory=dict
107112
)
113+
head_tolerations: Optional[List[V1Toleration]] = None
108114
worker_cpu_requests: Union[int, str] = 1
109115
worker_cpu_limits: Union[int, str] = 1
110116
min_cpus: Optional[Union[int, str]] = None # Deprecating
@@ -115,6 +121,7 @@ class ClusterConfiguration:
115121
min_memory: Optional[Union[int, str]] = None # Deprecating
116122
max_memory: Optional[Union[int, str]] = None # Deprecating
117123
num_gpus: Optional[int] = None # Deprecating
124+
tolerations: Optional[List[V1Toleration]] = None
118125
appwrapper: bool = False
119126
envs: Dict[str, str] = field(default_factory=dict)
120127
image: str = ""
@@ -265,7 +272,10 @@ def check_type(value, expected_type):
265272
if origin_type is Union:
266273
return any(check_type(value, union_type) for union_type in args)
267274
if origin_type is list:
268-
return all(check_type(elem, args[0]) for elem in value)
275+
if value is not None:
276+
return all(check_type(elem, args[0]) for elem in value)
277+
else:
278+
return True
269279
if origin_type is dict:
270280
return all(
271281
check_type(k, args[0]) and check_type(v, args[1])

0 commit comments

Comments
 (0)