Skip to content

Commit c5c73f5

Browse files
committed
RHOAIENG-8098 - ClusterConfiguration should support tolerations
1 parent 6b0a3cc commit c5c73f5

File tree

5 files changed

+62
-9
lines changed

5 files changed

+62
-9
lines changed

src/codeflare_sdk/common/utils/unit_test_support.py

+11
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import yaml
2323
from pathlib import Path
2424
from kubernetes import client
25+
from kubernetes.client import V1Toleration
2526
from unittest.mock import patch
2627

2728
parent = Path(__file__).resolve().parents[4] # project directory
@@ -426,8 +427,18 @@ def create_cluster_all_config_params(mocker, cluster_name, is_appwrapper) -> Clu
426427
head_memory_requests=12,
427428
head_memory_limits=16,
428429
head_extended_resource_requests={"nvidia.com/gpu": 1, "intel.com/gpu": 2},
430+
head_tolerations=[
431+
V1Toleration(
432+
key="key1", operator="Equal", value="value1", effect="NoSchedule"
433+
)
434+
],
429435
worker_cpu_requests=4,
430436
worker_cpu_limits=8,
437+
tolerations=[
438+
V1Toleration(
439+
key="key2", operator="Equal", value="value2", effect="NoSchedule"
440+
)
441+
],
431442
num_workers=10,
432443
worker_memory_requests=12,
433444
worker_memory_limits=16,

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+20 -8
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
This sub-module exists primarily to be used internally by the Cluster object
1717
(in the cluster sub-module) for RayCluster/AppWrapper generation.
1818
"""
19-
from typing import Union, Tuple, Dict
19+
from typing import List, Union, Tuple, Dict
2020
from ...common import _kube_api_error_handling
2121
from ...common.kubernetes_cluster import get_api_client, config_check
2222
from kubernetes.client.exceptions import ApiException
@@ -40,6 +40,7 @@
4040
V1PodTemplateSpec,
4141
V1PodSpec,
4242
V1LocalObjectReference,
43+
V1Toleration,
4344
)
4445

4546
import yaml
@@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
139140
"resources": head_resources,
140141
},
141142
"template": {
142-
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
143+
"spec": get_pod_spec(
144+
cluster,
145+
[get_head_container_spec(cluster)],
146+
cluster.config.head_tolerations,
147+
)
143148
},
144149
},
145150
"workerGroupSpecs": [
@@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
154159
"resources": worker_resources,
155160
},
156161
"template": V1PodTemplateSpec(
157-
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
162+
spec=get_pod_spec(
163+
cluster,
164+
[get_worker_container_spec(cluster)],
165+
cluster.config.tolerations,
166+
)
158167
),
159168
}
160169
],
@@ -243,14 +252,17 @@ def update_image(image) -> str:
243252
return image
244253

245254

246-
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
255+
def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers, tolerations):
247256
"""
248257
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
249258
"""
250-
pod_spec = V1PodSpec(
251-
containers=containers,
252-
volumes=VOLUMES,
253-
)
259+
if tolerations is None:
260+
pod_spec = V1PodSpec(containers=containers, volumes=VOLUMES)
261+
else:
262+
pod_spec = V1PodSpec(
263+
containers=containers, volumes=VOLUMES, tolerations=tolerations
264+
)
265+
254266
if cluster.config.image_pull_secrets != []:
255267
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
256268

src/codeflare_sdk/ray/cluster/config.py

+11 -1
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Toleration
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -57,6 +58,8 @@ class ClusterConfiguration:
5758
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
5859
head_extended_resource_requests:
5960
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
head_tolerations:
62+
List of tolerations for head nodes.
6063
min_cpus:
6164
The minimum number of CPUs to allocate to each worker.
6265
max_cpus:
@@ -69,6 +72,8 @@ class ClusterConfiguration:
6972
The maximum amount of memory to allocate to each worker.
7073
num_gpus:
7174
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75+
tolerations:
76+
List of tolerations for worker nodes.
7277
appwrapper:
7378
A boolean indicating whether to use an AppWrapper.
7479
envs:
@@ -105,6 +110,7 @@ class ClusterConfiguration:
105110
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
106111
default_factory=dict
107112
)
113+
head_tolerations: Optional[List[V1Toleration]] = None
108114
worker_cpu_requests: Union[int, str] = 1
109115
worker_cpu_limits: Union[int, str] = 1
110116
min_cpus: Optional[Union[int, str]] = None # Deprecating
@@ -115,6 +121,7 @@ class ClusterConfiguration:
115121
min_memory: Optional[Union[int, str]] = None # Deprecating
116122
max_memory: Optional[Union[int, str]] = None # Deprecating
117123
num_gpus: Optional[int] = None # Deprecating
124+
tolerations: Optional[List[V1Toleration]] = None
118125
appwrapper: bool = False
119126
envs: Dict[str, str] = field(default_factory=dict)
120127
image: str = ""
@@ -265,7 +272,10 @@ def check_type(value, expected_type):
265272
if origin_type is Union:
266273
return any(check_type(value, union_type) for union_type in args)
267274
if origin_type is list:
268-
return all(check_type(elem, args[0]) for elem in value)
275+
if value is not None:
276+
return all(check_type(elem, args[0]) for elem in value)
277+
else:
278+
return True
269279
if origin_type is dict:
270280
return all(
271281
check_type(k, args[0]) and check_type(v, args[1])

tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml

+10
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,11 @@ spec:
9393
imagePullSecrets:
9494
- name: secret1
9595
- name: secret2
96+
tolerations:
97+
- effect: NoSchedule
98+
key: key1
99+
operator: Equal
100+
value: value1
96101
volumes:
97102
- configMap:
98103
items:
@@ -161,6 +166,11 @@ spec:
161166
imagePullSecrets:
162167
- name: secret1
163168
- name: secret2
169+
tolerations:
170+
- effect: NoSchedule
171+
key: key2
172+
operator: Equal
173+
value: value2
164174
volumes:
165175
- configMap:
166176
items:

tests/test_cluster_yamls/ray/unit-test-all-params.yaml

+10
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,11 @@ spec:
8484
imagePullSecrets:
8585
- name: secret1
8686
- name: secret2
87+
tolerations:
88+
- effect: NoSchedule
89+
key: key1
90+
operator: Equal
91+
value: value1
8792
volumes:
8893
- configMap:
8994
items:
@@ -152,6 +157,11 @@ spec:
152157
imagePullSecrets:
153158
- name: secret1
154159
- name: secret2
160+
tolerations:
161+
- effect: NoSchedule
162+
key: key2
163+
operator: Equal
164+
value: value2
155165
volumes:
156166
- configMap:
157167
items:

0 commit comments

Comments
 (0)