diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 03b4da1e90..11702647d8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -535,18 +535,9 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.GetPodSpec(), err.Error()) } - for _, container := range customPodSpec.Containers { - if container.Name != primaryContainer.Name { // Only support the primary container for now - continue - } - - if len(container.Resources.Requests) > 0 || len(container.Resources.Limits) > 0 { - primaryContainer.Resources = container.Resources - } - } - - if customPodSpec.RuntimeClassName != nil { - podSpec.RuntimeClassName = customPodSpec.RuntimeClassName + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainer.Name, "") + if err != nil { + return nil, err } return podSpec, nil diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 3313e29d18..b92c8b044a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -432,6 +432,13 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { } } +type rayPodAssertions struct { + resources *corev1.ResourceRequirements + runtimeClassName *string + tolerations []corev1.Toleration + affinity *corev1.Affinity +} + func TestBuildResourceRayCustomK8SPod(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) @@ -457,76 +464,197 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { nvidiaRuntimeClassName := "nvidia-cdi" - headPodSpecCustomResources := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "ray-head", - Resources: *expectedHeadResources, - }, + headTolerations := []corev1.Toleration{ + { + Key: "head", + Operator: corev1.TolerationOpEqual, + Value: "true", + Effect: corev1.TaintEffectNoSchedule, }, } - workerPodSpecCustomResources := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "ray-worker", - Resources: *expectedWorkerResources, - }, + + workerTolerations := []corev1.Toleration{ + { + Key: "worker", + Operator: corev1.TolerationOpEqual, + Value: "true", + Effect: corev1.TaintEffectNoSchedule, }, } - headPodSpecCustomRuntimeClass := &corev1.PodSpec{ - RuntimeClassName: &nvidiaRuntimeClassName, + headAffinity := &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "node-type", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"head-node"}, + }, + }, + }, + }, + }, + }, } - workerPodSpecCustomRuntimeClass := &corev1.PodSpec{ - RuntimeClassName: &nvidiaRuntimeClassName, + + workerAffinity := &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "node-type", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"worker-node"}, + }, + }, + }, + }, + }, + }, } params := []struct { - name string - taskResources *corev1.ResourceRequirements - headK8SPod *core.K8SPod - workerK8SPod *core.K8SPod - expectedSubmitterResources *corev1.ResourceRequirements - expectedHeadResources *corev1.ResourceRequirements - expectedWorkerResources *corev1.ResourceRequirements - expectedSubmitterRuntimeClassName *string - expectedHeadRuntimeClassName *string - expectedWorkerRuntimeClassName *string + name string + headK8SPod *core.K8SPod + workerK8SPod *core.K8SPod + headPodAssertions rayPodAssertions + workerPodAssertions rayPodAssertions }{ { - name: "task resources", - taskResources: resourceRequirements, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: resourceRequirements, - expectedWorkerResources: resourceRequirements, + name: "no customizations", + headK8SPod: &core.K8SPod{}, + workerK8SPod: &core.K8SPod{}, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + }, + }, + { + name: "custom worker and head resources", + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Resources: *expectedHeadResources, + }, + }, + }), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Resources: *expectedWorkerResources, + }, + }, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: expectedHeadResources, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: expectedWorkerResources, + }, + }, + { + name: "custom runtime class name", + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + Containers: []corev1.Container{ + { + Name: "ray-head", + }, + }, + }), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + Containers: []corev1.Container{ + { + Name: "ray-worker", + }, + }, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + runtimeClassName: &nvidiaRuntimeClassName, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + runtimeClassName: &nvidiaRuntimeClassName, + }, }, { - name: "custom worker and head resources", - taskResources: resourceRequirements, + name: "custom tolerations", headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpecCustomResources), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + }, + }, + Tolerations: headTolerations, + }), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpecCustomResources), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + }, + }, + Tolerations: workerTolerations, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + tolerations: headTolerations, + }, + workerPodAssertions: rayPodAssertions{ + affinity: &corev1.Affinity{}, + resources: resourceRequirements, + tolerations: workerTolerations, }, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: expectedHeadResources, - expectedWorkerResources: expectedWorkerResources, }, { - name: "custom runtime class name", - taskResources: resourceRequirements, - expectedSubmitterResources: resourceRequirements, - expectedHeadResources: resourceRequirements, - expectedWorkerResources: resourceRequirements, + name: "custom affinity", headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpecCustomRuntimeClass), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Affinity: headAffinity, + }), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpecCustomRuntimeClass), + PodSpec: transformStructToStructPB(t, &corev1.PodSpec{ + Affinity: workerAffinity, + }), + }, + headPodAssertions: rayPodAssertions{ + affinity: headAffinity, + resources: resourceRequirements, + }, + workerPodAssertions: rayPodAssertions{ + affinity: workerAffinity, + resources: resourceRequirements, }, - expectedHeadRuntimeClassName: &nvidiaRuntimeClassName, - expectedWorkerRuntimeClassName: &nvidiaRuntimeClassName, }, } @@ -545,7 +673,7 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { } taskTemplate := dummyRayTaskTemplate("ray-id", rayJobInput) - taskContext := dummyRayTaskContext(taskTemplate, p.taskResources, nil, "", serviceAccount) + taskContext := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -553,29 +681,33 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { rayJob, ok := r.(*rayv1.RayJob) assert.True(t, ok) - submitterPodResources := rayJob.Spec.SubmitterPodTemplate.Spec.Containers[0].Resources - assert.EqualValues(t, - p.expectedSubmitterResources, - &submitterPodResources, - ) - headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec headPodResources := headPodSpec.Containers[0].Resources - assert.EqualValues(t, - p.expectedHeadResources, - &headPodResources, - ) + if p.headPodAssertions.resources != nil { + assert.EqualValues(t, + *p.headPodAssertions.resources, + headPodResources, + ) + } - assert.EqualValues(t, p.expectedHeadRuntimeClassName, headPodSpec.RuntimeClassName) + assert.EqualValues(t, p.headPodAssertions.runtimeClassName, headPodSpec.RuntimeClassName) + assert.EqualValues(t, p.headPodAssertions.tolerations, headPodSpec.Tolerations) + assert.EqualValues(t, p.headPodAssertions.affinity, headPodSpec.Affinity) for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { workerPodSpec := workerGroupSpec.Template.Spec workerPodResources := workerPodSpec.Containers[0].Resources - assert.EqualValues(t, - p.expectedWorkerResources, - &workerPodResources, - ) - assert.EqualValues(t, p.expectedWorkerRuntimeClassName, workerPodSpec.RuntimeClassName) + + if p.workerPodAssertions.resources != nil { + assert.EqualValues(t, + *p.workerPodAssertions.resources, + workerPodResources, + ) + } + + assert.EqualValues(t, p.workerPodAssertions.runtimeClassName, workerPodSpec.RuntimeClassName) + assert.EqualValues(t, p.workerPodAssertions.tolerations, workerPodSpec.Tolerations) + assert.EqualValues(t, p.workerPodAssertions.affinity, workerPodSpec.Affinity) } }) }