From 5077dec0192137cae3ebf37875cdcb7e0e2d81de Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Fri, 7 Feb 2025 19:23:06 -0800 Subject: [PATCH 1/4] Add support for affinity and tolerations, refactor unit tests Signed-off-by: Jason Parraga --- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 8 + .../go/tasks/plugins/k8s/ray/ray_test.go | 237 +++++++++++++----- 2 files changed, 181 insertions(+), 64 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 03b4da1e90..a22d0c49f0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -549,6 +549,14 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s podSpec.RuntimeClassName = customPodSpec.RuntimeClassName } + if len(customPodSpec.Tolerations) > 0 { + podSpec.Tolerations = customPodSpec.Tolerations + } + + if customPodSpec.Affinity != nil { + podSpec.Affinity = customPodSpec.Affinity + } + return podSpec, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 3313e29d18..d973d06640 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -432,6 +432,13 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { } } +type rayPodAssertions struct { + resources *corev1.ResourceRequirements + runtimeClassName *string + tolerations []corev1.Toleration + affinity *corev1.Affinity +} + func TestBuildResourceRayCustomK8SPod(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) @@ -457,76 +464,174 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { nvidiaRuntimeClassName := "nvidia-cdi" - headPodSpecCustomResources := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "ray-head", - Resources: *expectedHeadResources, - }, + headTolerations := []corev1.Toleration{ + { + Key: "head", + Operator: corev1.TolerationOpEqual, + Value: "true", + Effect: corev1.TaintEffectNoSchedule, }, } - workerPodSpecCustomResources := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "ray-worker", - Resources: *expectedWorkerResources, - }, + + workerTolerations := []corev1.Toleration{ + { + Key: "worker", + Operator: corev1.TolerationOpEqual, + Value: "true", + Effect: corev1.TaintEffectNoSchedule, }, } - headPodSpecCustomRuntimeClass := &corev1.PodSpec{ - RuntimeClassName: &nvidiaRuntimeClassName, + headPodSpecCustomTolerations := &corev1.PodSpec{ + Tolerations: headTolerations, + } + workerPodSpecCustomTolerations := &corev1.PodSpec{ + Tolerations: workerTolerations, + } + + headAffinity := &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "node-type", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"head-node"}, + }, + }, + }, + }, + }, + }, } - workerPodSpecCustomRuntimeClass := &corev1.PodSpec{ - RuntimeClassName: &nvidiaRuntimeClassName, + + workerAffinity := &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "node-type", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"worker-node"}, + }, + }, + }, + }, + }, + }, } params := []struct { - name string - taskResources *corev1.ResourceRequirements - headK8SPod *core.K8SPod - workerK8SPod *core.K8SPod - expectedSubmitterResources *corev1.ResourceRequirements - expectedHeadResources *corev1.ResourceRequirements - expectedWorkerResources *corev1.ResourceRequirements - expectedSubmitterRuntimeClassName *string - expectedHeadRuntimeClassName *string - expectedWorkerRuntimeClassName *string + name string + headK8SPod *core.K8SPod + workerK8SPod *core.K8SPod + headPodAssertions rayPodAssertions + workerPodAssertions rayPodAssertions }{ { - name: "task resources", - taskResources: resourceRequirements, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: resourceRequirements, - expectedWorkerResources: resourceRequirements, + name: "no customizations", + headK8SPod: &core.K8SPod{}, + workerK8SPod: &core.K8SPod{}, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + }, }, { - name: "custom worker and head resources", - taskResources: resourceRequirements, + name: "custom worker and head resources", headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpecCustomResources), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Resources: *expectedHeadResources, + }, + }, + }), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpecCustomResources), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Resources: *expectedWorkerResources, + }, + }, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: expectedHeadResources, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: expectedWorkerResources, + }, + }, + { + name: "custom runtime class name", + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + }), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + runtimeClassName: &nvidiaRuntimeClassName, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + runtimeClassName: &nvidiaRuntimeClassName, }, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: expectedHeadResources, - expectedWorkerResources: expectedWorkerResources, }, { - name: "custom runtime class name", - taskResources: resourceRequirements, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: resourceRequirements, - expectedWorkerResources: resourceRequirements, + name: "custom tolerations", headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpecCustomRuntimeClass), + PodSpec: transformStructToStructPB(t, headPodSpecCustomTolerations), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpecCustomRuntimeClass), + PodSpec: transformStructToStructPB(t, workerPodSpecCustomTolerations), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + tolerations: headTolerations, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + tolerations: workerTolerations, + }, + }, + { + name: "custom affinity", + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Affinity: headAffinity, + }), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Affinity: workerAffinity, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: headAffinity, + }, + workerPodAssertions: rayPodAssertions{ + affinity: workerAffinity, }, - expectedHeadRuntimeClassName: &nvidiaRuntimeClassName, - expectedWorkerRuntimeClassName: &nvidiaRuntimeClassName, }, } @@ -545,7 +650,7 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { } taskTemplate := dummyRayTaskTemplate("ray-id", rayJobInput) - taskContext := dummyRayTaskContext(taskTemplate, p.taskResources, nil, "", serviceAccount) + taskContext := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -553,29 +658,33 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { rayJob, ok := r.(*rayv1.RayJob) assert.True(t, ok) - submitterPodResources := rayJob.Spec.SubmitterPodTemplate.Spec.Containers[0].Resources - assert.EqualValues(t, - p.expectedSubmitterResources, - &submitterPodResources, - ) - headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec headPodResources := headPodSpec.Containers[0].Resources - assert.EqualValues(t, - p.expectedHeadResources, - &headPodResources, - ) + if p.headPodAssertions.resources != nil { + assert.EqualValues(t, + *p.headPodAssertions.resources, + headPodResources, + ) + } - assert.EqualValues(t, p.expectedHeadRuntimeClassName, headPodSpec.RuntimeClassName) + assert.EqualValues(t, p.headPodAssertions.runtimeClassName, headPodSpec.RuntimeClassName) + assert.EqualValues(t, p.headPodAssertions.tolerations, headPodSpec.Tolerations) + assert.EqualValues(t, p.headPodAssertions.affinity, headPodSpec.Affinity) for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { workerPodSpec := workerGroupSpec.Template.Spec workerPodResources := workerPodSpec.Containers[0].Resources - assert.EqualValues(t, - p.expectedWorkerResources, - &workerPodResources, - ) - assert.EqualValues(t, p.expectedWorkerRuntimeClassName, workerPodSpec.RuntimeClassName) + + if p.workerPodAssertions.resources != nil { + assert.EqualValues(t, + *p.workerPodAssertions.resources, + workerPodResources, + ) + } + + assert.EqualValues(t, p.workerPodAssertions.runtimeClassName, workerPodSpec.RuntimeClassName) + assert.EqualValues(t, p.workerPodAssertions.tolerations, workerPodSpec.Tolerations) + assert.EqualValues(t, p.workerPodAssertions.affinity, workerPodSpec.Affinity) } }) } From e2704490b99d221552aedecf6e7a6e61c86aaf70 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Sat, 8 Feb 2025 10:46:29 -0800 Subject: [PATCH 2/4] Retain coverage Signed-off-by: Jason Parraga --- flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index d973d06640..b09f24642d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -590,10 +590,12 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }, headPodAssertions: rayPodAssertions{ affinity: &corev1.Affinity{}, + resources: resourceRequirements, runtimeClassName: &nvidiaRuntimeClassName, }, workerPodAssertions: rayPodAssertions{ affinity: &corev1.Affinity{}, + resources: resourceRequirements, runtimeClassName: &nvidiaRuntimeClassName, }, }, @@ -607,10 +609,12 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }, headPodAssertions: rayPodAssertions{ affinity: &corev1.Affinity{}, + resources: resourceRequirements, tolerations: headTolerations, }, workerPodAssertions: rayPodAssertions{ affinity: &corev1.Affinity{}, + resources: resourceRequirements, tolerations: workerTolerations, }, }, @@ -627,10 +631,12 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }), }, headPodAssertions: rayPodAssertions{ - affinity: headAffinity, + affinity: headAffinity, + resources: resourceRequirements, }, workerPodAssertions: rayPodAssertions{ - affinity: workerAffinity, + affinity: workerAffinity, + resources: resourceRequirements, }, }, } From 2c7d31ca797f41667216a751709d220b98e42b42 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Tue, 18 Feb 2025 22:00:52 -0800 Subject: [PATCH 3/4] example unit test failure Signed-off-by: Jason Parraga --- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 23 +++----------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index a22d0c49f0..11702647d8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -535,26 +535,9 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.GetPodSpec(), err.Error()) } - for _, container := range customPodSpec.Containers { - if container.Name != primaryContainer.Name { // Only support the primary container for now - continue - } - - if len(container.Resources.Requests) > 0 || len(container.Resources.Limits) > 0 { - primaryContainer.Resources = container.Resources - } - } - - if customPodSpec.RuntimeClassName != nil { - podSpec.RuntimeClassName = customPodSpec.RuntimeClassName - } - - if len(customPodSpec.Tolerations) > 0 { - podSpec.Tolerations = customPodSpec.Tolerations - } - - if customPodSpec.Affinity != nil { - podSpec.Affinity = customPodSpec.Affinity + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainer.Name, "") + if err != nil { + return nil, err } return podSpec, nil From ff4cfc3d07e5998659274a21f95744da99ae2aa3 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Tue, 18 Feb 2025 22:06:09 -0800 Subject: [PATCH 4/4] update unit tests Signed-off-by: Jason Parraga --- .../go/tasks/plugins/k8s/ray/ray_test.go | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index b09f24642d..b92c8b044a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -482,13 +482,6 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }, } - headPodSpecCustomTolerations := &corev1.PodSpec{ - Tolerations: headTolerations, - } - workerPodSpecCustomTolerations := &corev1.PodSpec{ - Tolerations: workerTolerations, - } - headAffinity := &corev1.Affinity{ NodeAffinity: &corev1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ @@ -581,11 +574,21 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { headK8SPod: &core.K8SPod{ PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ RuntimeClassName: &nvidiaRuntimeClassName, + Containers: []corev1.Container{ + { + Name: "ray-head", + }, + }, }), }, workerK8SPod: &core.K8SPod{ PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ RuntimeClassName: &nvidiaRuntimeClassName, + Containers: []corev1.Container{ + { + Name: "ray-worker", + }, + }, }), }, headPodAssertions: rayPodAssertions{ @@ -602,10 +605,24 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { { name: "custom tolerations", headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpecCustomTolerations), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + }, + }, + Tolerations: headTolerations, + }), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpecCustomTolerations), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + }, + }, + Tolerations: workerTolerations, + }), }, headPodAssertions: rayPodAssertions{ affinity: &corev1.Affinity{},