Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions app/deployments/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,6 +1937,7 @@ def patch_template(
k_def_gpus = "num_gpus"
k_def_gpu_model = "gpu_model"
k_def_gpu_vendor = "gpu_vendor"
k_def_enable_gpu = "enable_gpu"

# override template flavors with provider flavors
for k, v in list(template[k_inputs].items()):
Expand All @@ -1949,6 +1950,7 @@ def patch_template(
k_gpus = None
k_gpu_model = None
k_gpu_vendor = None
k_enable_gpu = None
if k_constraints in v:
for ff in v[k_constraints]:
# search for cpu key
Expand Down Expand Up @@ -1987,6 +1989,12 @@ def patch_template(
if re.search(k_def_gpu_vendor, fk):
k_gpu_vendor = fk
break
# search for enable_gpu key
if not k_enable_gpu:
for fk in ff[k_set].keys():
if re.search(k_def_enable_gpu, fk):
k_enable_gpu = fk
break
if k_group_overrides in v:
for kk,vv in v[k_group_overrides].items():
if k_constraints in vv:
Expand Down Expand Up @@ -2027,6 +2035,12 @@ def patch_template(
if re.search(k_def_gpu_vendor, fk):
k_gpu_vendor = fk
break
# search for enable_gpu key
if not k_enable_gpu:
for fk in ff[k_set].keys():
if re.search(k_def_enable_gpu, fk):
k_enable_gpu = fk
break

if not k_mem:
k_mem = k_def_mem
Expand All @@ -2040,6 +2054,8 @@ def patch_template(
k_gpu_model = k_def_gpu_model
if not k_gpu_vendor:
k_gpu_vendor = k_def_gpu_vendor
if not k_enable_gpu:
k_enable_gpu = k_def_enable_gpu

rflavors = list()

Expand Down Expand Up @@ -2101,6 +2117,13 @@ def patch_template(
if isinstance(c, dict):
if k_valid_values in c:
valid_values[k_gpu_vendor] = c.get(k_valid_values)

if k_enable_gpu in template[k_inputs] and not k_enable_gpu in valid_values:
if k_constraints in template[k_inputs][k_enable_gpu]:
for c in template[k_inputs][k_enable_gpu][k_constraints]:
if isinstance(c, dict):
if k_valid_values in c:
valid_values[k_enable_gpu] = c.get(k_valid_values)

for f in ff:
#filter constraints
Expand Down Expand Up @@ -2140,6 +2163,10 @@ def patch_template(
if not f[k_set][k_def_gpu_vendor].lower() in [x.lower() for x in valid_values[k_gpu_vendor]]:
continue

if k_enable_gpu in valid_values:
if not f[k_set][k_def_enable_gpu].lower() in [x.lower() for x in valid_values[k_enable_gpu]]:
continue


flavor = {
"value": f["value"],
Expand All @@ -2151,6 +2178,7 @@ def patch_template(
k_gpus: "{}".format(f[k_set][k_def_gpus]),
k_gpu_model: "{}".format(f[k_set][k_def_gpu_model]),
k_gpu_vendor: "{}".format(f[k_set][k_def_gpu_vendor]),
k_enable_gpu: "{}".format(f[k_set][k_def_enable_gpu]),
},
}
rflavors.append(flavor)
Expand Down