Skip to content

Commit 27e023c

Browse files
committed
Updated tests and load_components
1 parent 10b8a36 commit 27e023c

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,23 @@ def write_components(
692692
print(f"Written to: {output_file_name}")
693693

694694

695-
def load_components(user_yaml: dict, name: str):
695+
def load_components(
696+
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
697+
):
696698
component_list = []
697699
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
700+
lq_name = local_queue or get_default_kueue_name(namespace)
698701
for component in components:
699702
if "generictemplate" in component:
703+
if (
704+
"workload.codeflare.dev/appwrapper"
705+
in component["generictemplate"]["metadata"]["labels"]
706+
):
707+
del component["generictemplate"]["metadata"]["labels"][
708+
"workload.codeflare.dev/appwrapper"
709+
]
710+
labels = component["generictemplate"]["metadata"]["labels"]
711+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
700712
component_list.append(component["generictemplate"])
701713

702714
resources = "---\n" + "---\n".join(
@@ -807,11 +819,11 @@ def generate_appwrapper(
807819
if mcad:
808820
write_user_appwrapper(user_yaml, outfile)
809821
else:
810-
write_components(user_yaml, outfile, local_queue)
822+
write_components(user_yaml, outfile, namespace, local_queue)
811823
return outfile
812824
else:
813825
if mcad:
814826
user_yaml = load_appwrapper(user_yaml, name)
815827
else:
816-
user_yaml = load_components(user_yaml, name)
828+
user_yaml = load_components(user_yaml, name, namespace, local_queue)
817829
return user_yaml

tests/unit_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,6 @@ def get_local_queue(group, version, namespace, plural):
332332

333333

334334
def test_cluster_creation_no_mcad(mocker):
335-
# With written resources
336335
# Create Ray Cluster with no local queue specified
337336
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
338337
mocker.patch(
@@ -358,6 +357,7 @@ def test_cluster_creation_no_mcad(mocker):
358357

359358

360359
def test_cluster_creation_no_mcad_local_queue(mocker):
360+
# With written resources
361361
# Create Ray Cluster with local queue specified
362362
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
363363
mocker.patch(
@@ -394,6 +394,7 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
394394
image="quay.io/project-codeflare/ray:latest-py39-cu118",
395395
write_to_file=False,
396396
mcad=False,
397+
local_queue="local-queue-default",
397398
)
398399
cluster = Cluster(config)
399400
test_resources = []

0 commit comments

Comments
 (0)