Skip to content

RHOAIENG-8098 - ClusterConfiguration should support tolerations #800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/codeflare_sdk/common/utils/unit_test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import yaml
from pathlib import Path
from kubernetes import client
from kubernetes.client import V1Toleration
from unittest.mock import patch

parent = Path(__file__).resolve().parents[4] # project directory
Expand Down Expand Up @@ -427,8 +428,18 @@ def create_cluster_all_config_params(mocker, cluster_name, is_appwrapper) -> Clu
head_memory_requests=12,
head_memory_limits=16,
head_extended_resource_requests={"nvidia.com/gpu": 1, "intel.com/gpu": 2},
head_tolerations=[
V1Toleration(
key="key1", operator="Equal", value="value1", effect="NoSchedule"
)
],
worker_cpu_requests=4,
worker_cpu_limits=8,
worker_tolerations=[
V1Toleration(
key="key2", operator="Equal", value="value2", effect="NoSchedule"
)
],
num_workers=10,
worker_memory_requests=12,
worker_memory_limits=16,
Expand Down
24 changes: 20 additions & 4 deletions src/codeflare_sdk/ray/cluster/build_ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
This sub-module exists primarily to be used internally by the Cluster object
(in the cluster sub-module) for RayCluster/AppWrapper generation.
"""
from typing import Union, Tuple, Dict
from typing import List, Union, Tuple, Dict
from ...common import _kube_api_error_handling
from ...common.kubernetes_cluster import get_api_client, config_check
from kubernetes.client.exceptions import ApiException
Expand All @@ -40,6 +40,7 @@
V1PodTemplateSpec,
V1PodSpec,
V1LocalObjectReference,
V1Toleration,
)

import yaml
Expand Down Expand Up @@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"resources": head_resources,
},
"template": {
"spec": get_pod_spec(cluster, [get_head_container_spec(cluster)])
"spec": get_pod_spec(
cluster,
[get_head_container_spec(cluster)],
cluster.config.head_tolerations,
)
},
},
"workerGroupSpecs": [
Expand All @@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
"resources": worker_resources,
},
"template": V1PodTemplateSpec(
spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)])
spec=get_pod_spec(
cluster,
[get_worker_container_spec(cluster)],
cluster.config.worker_tolerations,
)
),
}
],
Expand Down Expand Up @@ -243,14 +252,21 @@ def update_image(image) -> str:
return image


def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
def get_pod_spec(
cluster: "codeflare_sdk.ray.cluster.Cluster",
containers: List,
tolerations: List[V1Toleration],
) -> V1PodSpec:
"""
The get_pod_spec() function generates a V1PodSpec for the head/worker containers
"""

pod_spec = V1PodSpec(
containers=containers,
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
tolerations=tolerations or None,
)

if cluster.config.image_pull_secrets != []:
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)

Expand Down
13 changes: 11 additions & 2 deletions src/codeflare_sdk/ray/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import warnings
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union, get_args, get_origin
from kubernetes.client import V1Volume, V1VolumeMount
from kubernetes.client import V1Toleration, V1Volume, V1VolumeMount

dir = pathlib.Path(__file__).parent.parent.resolve()

Expand Down Expand Up @@ -58,6 +58,8 @@ class ClusterConfiguration:
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
head_extended_resource_requests:
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
head_tolerations:
List of tolerations for head nodes.
min_cpus:
The minimum number of CPUs to allocate to each worker.
max_cpus:
Expand All @@ -70,6 +72,8 @@ class ClusterConfiguration:
The maximum amount of memory to allocate to each worker.
num_gpus:
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
worker_tolerations:
List of tolerations for worker nodes.
appwrapper:
A boolean indicating whether to use an AppWrapper.
envs:
Expand Down Expand Up @@ -110,6 +114,7 @@ class ClusterConfiguration:
head_extended_resource_requests: Dict[str, Union[str, int]] = field(
default_factory=dict
)
head_tolerations: Optional[List[V1Toleration]] = None
worker_cpu_requests: Union[int, str] = 1
worker_cpu_limits: Union[int, str] = 1
min_cpus: Optional[Union[int, str]] = None # Deprecating
Expand All @@ -120,6 +125,7 @@ class ClusterConfiguration:
min_memory: Optional[Union[int, str]] = None # Deprecating
max_memory: Optional[Union[int, str]] = None # Deprecating
num_gpus: Optional[int] = None # Deprecating
worker_tolerations: Optional[List[V1Toleration]] = None
appwrapper: bool = False
envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
Expand Down Expand Up @@ -272,7 +278,10 @@ def check_type(value, expected_type):
if origin_type is Union:
return any(check_type(value, union_type) for union_type in args)
if origin_type is list:
return all(check_type(elem, args[0]) for elem in value)
if value is not None:
return all(check_type(elem, args[0]) for elem in (value or []))
else:
return True
if origin_type is dict:
return all(
check_type(k, args[0]) and check_type(v, args[1])
Expand Down
10 changes: 10 additions & 0 deletions tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ spec:
imagePullSecrets:
- name: secret1
- name: secret2
tolerations:
- effect: NoSchedule
key: key1
operator: Equal
value: value1
volumes:
- emptyDir:
sizeLimit: 500Gi
Expand Down Expand Up @@ -185,6 +190,11 @@ spec:
imagePullSecrets:
- name: secret1
- name: secret2
tolerations:
- effect: NoSchedule
key: key2
operator: Equal
value: value2
volumes:
- emptyDir:
sizeLimit: 500Gi
Expand Down
10 changes: 10 additions & 0 deletions tests/test_cluster_yamls/ray/unit-test-all-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ spec:
imagePullSecrets:
- name: secret1
- name: secret2
tolerations:
- effect: NoSchedule
key: key1
operator: Equal
value: value1
volumes:
- emptyDir:
sizeLimit: 500Gi
Expand Down Expand Up @@ -176,6 +181,11 @@ spec:
imagePullSecrets:
- name: secret1
- name: secret2
tolerations:
- effect: NoSchedule
key: key2
operator: Equal
value: value2
volumes:
- emptyDir:
sizeLimit: 500Gi
Expand Down
Loading