22
22
import warnings
23
23
from dataclasses import dataclass , field , fields
24
24
from typing import Dict , List , Optional , Union , get_args , get_origin
25
+ from kubernetes .client import V1Toleration
25
26
26
27
dir = pathlib .Path (__file__ ).parent .parent .resolve ()
27
28
@@ -57,6 +58,8 @@ class ClusterConfiguration:
57
58
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
58
59
head_extended_resource_requests:
59
60
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61
+ head_tolerations:
62
+ List of tolerations for head nodes.
60
63
min_cpus:
61
64
The minimum number of CPUs to allocate to each worker.
62
65
max_cpus:
@@ -69,6 +72,8 @@ class ClusterConfiguration:
69
72
The maximum amount of memory to allocate to each worker.
70
73
num_gpus:
71
74
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75
+ tolerations:
76
+ List of tolerations for worker nodes.
72
77
appwrapper:
73
78
A boolean indicating whether to use an AppWrapper.
74
79
envs:
@@ -105,6 +110,7 @@ class ClusterConfiguration:
105
110
head_extended_resource_requests : Dict [str , Union [str , int ]] = field (
106
111
default_factory = dict
107
112
)
113
+ head_tolerations : Optional [List [V1Toleration ]] = None
108
114
worker_cpu_requests : Union [int , str ] = 1
109
115
worker_cpu_limits : Union [int , str ] = 1
110
116
min_cpus : Optional [Union [int , str ]] = None # Deprecating
@@ -115,6 +121,7 @@ class ClusterConfiguration:
115
121
min_memory : Optional [Union [int , str ]] = None # Deprecating
116
122
max_memory : Optional [Union [int , str ]] = None # Deprecating
117
123
num_gpus : Optional [int ] = None # Deprecating
124
+ tolerations : Optional [List [V1Toleration ]] = None
118
125
appwrapper : bool = False
119
126
envs : Dict [str , str ] = field (default_factory = dict )
120
127
image : str = ""
@@ -265,7 +272,10 @@ def check_type(value, expected_type):
265
272
if origin_type is Union :
266
273
return any (check_type (value , union_type ) for union_type in args )
267
274
if origin_type is list :
268
- return all (check_type (elem , args [0 ]) for elem in value )
275
+ if value is not None :
276
+ return all (check_type (elem , args [0 ]) for elem in value )
277
+ else :
278
+ return True
269
279
if origin_type is dict :
270
280
return all (
271
281
check_type (k , args [0 ]) and check_type (v , args [1 ])
0 commit comments