From 49a2dd63d6df4ad3e15dc078895f007df8bd3cd6 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Mon, 24 Feb 2025 14:37:29 -0800 Subject: [PATCH] Rework MergePodSpecs logic (#6262) * Rework pod spec merge Signed-off-by: Jason Parraga * Rework again so semantics are clearer about templates vs overlays Signed-off-by: Jason Parraga * Rework two use cases into two methods Signed-off-by: Jason Parraga * Add docs Signed-off-by: Jason Parraga * lint-fix Signed-off-by: Jason Parraga * Add unit test coverage Signed-off-by: Jason Parraga * Add unit tests for overlay Signed-off-by: Jason Parraga --------- Signed-off-by: Jason Parraga --- .../pluginmachinery/flytek8s/pod_helper.go | 161 ++++-- .../flytek8s/pod_helper_test.go | 516 ++++++++++++++---- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 14 +- .../go/tasks/plugins/k8s/spark/spark.go | 4 +- 4 files changed, 542 insertions(+), 153 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index c9b609eb10..cbfe14c87d 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -670,8 +670,9 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio return podSpec, objectMeta, nil } - // merge podSpec with podTemplate - mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) + // merge podTemplate onto podSpec + templateSpec := &podTemplate.Template.Spec + mergedPodSpec, err := MergeBasePodSpecOntoTemplate(templateSpec, podSpec, primaryContainerName, primaryInitContainerName) if err != nil { return nil, nil, err } @@ -685,50 +686,54 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio return mergedPodSpec, mergedObjectMeta, nil } -// MergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values -// set by the first PodSpec are overwritten by the second in the return value. Additionally, this function applies -// container-level configuration from the basePodSpec. -func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { - if basePodSpec == nil || podSpec == nil { - return nil, errors.New("neither the basePodSpec or the podSpec can be nil") +// MergeBasePodSpecOntoTemplate merges a base pod spec onto a template pod spec. The template pod spec has some +// magic values that allow users to specify templates that target all containers and primary containers. Aside from +// magic values this method will merge containers that have matching names. +func MergeBasePodSpecOntoTemplate(templatePodSpec *v1.PodSpec, basePodSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { + if templatePodSpec == nil || basePodSpec == nil { + return nil, errors.New("neither the templatePodSpec or the basePodSpec can be nil") } - // extract defaultContainerTemplate and primaryContainerTemplate + // extract primaryContainerTemplate. The base should always contain the primary container. var defaultContainerTemplate, primaryContainerTemplate *v1.Container - for i := 0; i < len(basePodSpec.Containers); i++ { - if basePodSpec.Containers[i].Name == defaultContainerTemplateName { - defaultContainerTemplate = &basePodSpec.Containers[i] - } else if basePodSpec.Containers[i].Name == primaryContainerName { - primaryContainerTemplate = &basePodSpec.Containers[i] + + // extract default container template + for i := 0; i < len(templatePodSpec.Containers); i++ { + if templatePodSpec.Containers[i].Name == defaultContainerTemplateName { + defaultContainerTemplate = &templatePodSpec.Containers[i] + } else if templatePodSpec.Containers[i].Name == primaryContainerTemplateName { + primaryContainerTemplate = &templatePodSpec.Containers[i] } } - // extract defaultInitContainerTemplate and primaryInitContainerTemplate + // extract primaryInitContainerTemplate. The base should always contain the primary container. var defaultInitContainerTemplate, primaryInitContainerTemplate *v1.Container - for i := 0; i < len(basePodSpec.InitContainers); i++ { - if basePodSpec.InitContainers[i].Name == defaultInitContainerTemplateName { - defaultInitContainerTemplate = &basePodSpec.InitContainers[i] - } else if basePodSpec.InitContainers[i].Name == primaryInitContainerName { - primaryInitContainerTemplate = &basePodSpec.InitContainers[i] + + // extract defaultInitContainerTemplate + for i := 0; i < len(templatePodSpec.InitContainers); i++ { + if templatePodSpec.InitContainers[i].Name == defaultInitContainerTemplateName { + defaultInitContainerTemplate = &templatePodSpec.InitContainers[i] + } else if templatePodSpec.InitContainers[i].Name == primaryInitContainerTemplateName { + primaryInitContainerTemplate = &templatePodSpec.InitContainers[i] } } - // merge PodTemplate PodSpec with podSpec - var mergedPodSpec *v1.PodSpec = basePodSpec.DeepCopy() - if err := mergo.Merge(mergedPodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice); err != nil { + // Merge base into template + mergedPodSpec := templatePodSpec.DeepCopy() + if err := mergo.Merge(mergedPodSpec, basePodSpec, mergo.WithOverride, mergo.WithAppendSlice); err != nil { return nil, err } // merge PodTemplate containers var mergedContainers []v1.Container - for _, container := range podSpec.Containers { + for _, container := range basePodSpec.Containers { // if applicable start with defaultContainerTemplate var mergedContainer *v1.Container if defaultContainerTemplate != nil { mergedContainer = defaultContainerTemplate.DeepCopy() } - // if applicable merge with primaryContainerTemplate + // If this is a primary container handle the template if container.Name == primaryContainerName && primaryContainerTemplate != nil { if mergedContainer == nil { mergedContainer = primaryContainerTemplate.DeepCopy() @@ -740,35 +745,48 @@ func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContaine } } - // if applicable merge with existing container + // Check for any name matching template containers + for _, templateContainer := range templatePodSpec.Containers { + if templateContainer.Name != container.Name { + continue + } + + if mergedContainer == nil { + mergedContainer = &templateContainer + } else { + err := mergo.Merge(mergedContainer, templateContainer, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } + } + + // Merge in the base container if mergedContainer == nil { - mergedContainers = append(mergedContainers, container) + mergedContainer = container.DeepCopy() } else { err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { return nil, err } - - mergedContainers = append(mergedContainers, *mergedContainer) } - } - if mergedContainers == nil { - mergedContainers = basePodSpec.Containers + mergedContainers = append(mergedContainers, *mergedContainer) + } mergedPodSpec.Containers = mergedContainers // merge PodTemplate init containers var mergedInitContainers []v1.Container - for _, initContainer := range podSpec.InitContainers { + for _, initContainer := range basePodSpec.InitContainers { // if applicable start with defaultContainerTemplate var mergedInitContainer *v1.Container if defaultInitContainerTemplate != nil { mergedInitContainer = defaultInitContainerTemplate.DeepCopy() } - // if applicable merge with primaryInitContainerTemplate + // If this is a primary init container handle the template if initContainer.Name == primaryInitContainerName && primaryInitContainerTemplate != nil { if mergedInitContainer == nil { mergedInitContainer = primaryInitContainerTemplate.DeepCopy() @@ -780,21 +798,86 @@ func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContaine } } - // if applicable merge with existing init initContainer + // Check for any name matching template containers + for _, templateInitContainer := range templatePodSpec.InitContainers { + if templateInitContainer.Name != initContainer.Name { + continue + } + + if mergedInitContainer == nil { + mergedInitContainer = &templateInitContainer + } else { + err := mergo.Merge(mergedInitContainer, templateInitContainer, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } + } + + // Merge in the base init container if mergedInitContainer == nil { - mergedInitContainers = append(mergedInitContainers, initContainer) + mergedInitContainer = initContainer.DeepCopy() } else { err := mergo.Merge(mergedInitContainer, initContainer, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { return nil, err } + } + + mergedInitContainers = append(mergedInitContainers, *mergedInitContainer) + } - mergedInitContainers = append(mergedInitContainers, *mergedInitContainer) + mergedPodSpec.InitContainers = mergedInitContainers + + return mergedPodSpec, nil +} + +// MergeOverlayPodSpecOntoBase merges a customized pod spec onto a base pod spec. At a container level it will +// merge containers that have matching names. +func MergeOverlayPodSpecOntoBase(basePodSpec *v1.PodSpec, overlayPodSpec *v1.PodSpec) (*v1.PodSpec, error) { + if basePodSpec == nil || overlayPodSpec == nil { + return nil, errors.New("neither the basePodSpec or the overlayPodSpec can be nil") + } + + mergedPodSpec := basePodSpec.DeepCopy() + if err := mergo.Merge(mergedPodSpec, overlayPodSpec, mergo.WithOverride, mergo.WithAppendSlice); err != nil { + return nil, err + } + + // merge PodTemplate containers + var mergedContainers []v1.Container + for _, container := range basePodSpec.Containers { + + mergedContainer := container.DeepCopy() + + for _, overlayContainer := range overlayPodSpec.Containers { + if mergedContainer.Name == overlayContainer.Name { + err := mergo.Merge(mergedContainer, overlayContainer, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } } + mergedContainers = append(mergedContainers, *mergedContainer) } - if mergedInitContainers == nil { - mergedInitContainers = basePodSpec.InitContainers + mergedPodSpec.Containers = mergedContainers + + // merge PodTemplate init containers + var mergedInitContainers []v1.Container + for _, initContainer := range basePodSpec.InitContainers { + + mergedInitContainer := initContainer.DeepCopy() + + for _, overlayInitContainer := range overlayPodSpec.InitContainers { + if mergedInitContainer.Name == overlayInitContainer.Name { + err := mergo.Merge(mergedInitContainer, overlayInitContainer, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } + } + mergedInitContainers = append(mergedInitContainers, *mergedInitContainer) } mergedPodSpec.InitContainers = mergedInitContainers diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 6fab4ce455..556a7b3534 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -3,6 +3,7 @@ package flytek8s import ( "context" "encoding/json" + "errors" "fmt" "io/ioutil" "path/filepath" @@ -2014,12 +2015,12 @@ func TestMergeWithBasePodTemplate(t *testing.T) { t.Run("BasePodTemplateExists", func(t *testing.T) { primaryContainerTemplate := v1.Container{ - Name: "foo", + Name: primaryContainerTemplateName, TerminationMessagePath: "/dev/primary-termination-log", } primaryInitContainerTemplate := v1.Container{ - Name: "foo-init", + Name: primaryInitContainerTemplateName, TerminationMessagePath: "/dev/primary-init-termination-log", } @@ -2087,138 +2088,449 @@ func TestMergeWithBasePodTemplate(t *testing.T) { }) } -func TestMergePodSpecs(t *testing.T) { - var priority int32 = 1 +func TestMergeBasePodSpecsOntoTemplate(t *testing.T) { - podSpec1, _ := MergePodSpecs(nil, nil, "foo", "foo-init") - assert.Nil(t, podSpec1) + baseContainer1 := v1.Container{ + Name: "task-1", + Image: "task-image", + } - podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") - assert.Nil(t, podSpec2) + baseContainer2 := v1.Container{ + Name: "task-2", + Image: "task-image", + } - podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") - assert.Nil(t, podSpec3) + initContainer1 := v1.Container{ + Name: "task-init-1", + Image: "task-init-image", + } - podSpec := v1.PodSpec{ - Containers: []v1.Container{ - v1.Container{ - Name: "primary", - VolumeMounts: []v1.VolumeMount{ + initContainer2 := v1.Container{ + Name: "task-init-2", + Image: "task-init-image", + } + + tests := []struct { + name string + templatePodSpec *v1.PodSpec + basePodSpec *v1.PodSpec + primaryContainerName string + primaryInitContainerName string + expectedResult *v1.PodSpec + expectedError error + }{ + { + name: "nil template", + templatePodSpec: nil, + basePodSpec: &v1.PodSpec{}, + expectedError: errors.New("neither the templatePodSpec or the basePodSpec can be nil"), + }, + { + name: "nil base", + templatePodSpec: &v1.PodSpec{}, + basePodSpec: nil, + expectedError: errors.New("neither the templatePodSpec or the basePodSpec can be nil"), + }, + { + name: "nil template and base", + templatePodSpec: nil, + basePodSpec: nil, + expectedError: errors.New("neither the templatePodSpec or the basePodSpec can be nil"), + }, + { + name: "template and base with no overlap", + templatePodSpec: &v1.PodSpec{ + SchedulerName: "templateScheduler", + }, + basePodSpec: &v1.PodSpec{ + ServiceAccountName: "baseServiceAccount", + }, + expectedResult: &v1.PodSpec{ + SchedulerName: "templateScheduler", + ServiceAccountName: "baseServiceAccount", + }, + }, + { + name: "template and base with overlap", + templatePodSpec: &v1.PodSpec{ + SchedulerName: "templateScheduler", + }, + basePodSpec: &v1.PodSpec{ + SchedulerName: "baseScheduler", + ServiceAccountName: "baseServiceAccount", + }, + expectedResult: &v1.PodSpec{ + SchedulerName: "baseScheduler", + ServiceAccountName: "baseServiceAccount", + }, + }, + { + name: "template with default containers and base with no containers", + templatePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ { - Name: "nccl", - MountPath: "abc", + Name: "default", + Image: "default-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "default-init", + Image: "default-init-image", }, }, }, - v1.Container{ - Name: "bar", + basePodSpec: &v1.PodSpec{ + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + SchedulerName: "baseScheduler", }, }, - InitContainers: []v1.Container{ - v1.Container{ - Name: "primary-init", - VolumeMounts: []v1.VolumeMount{ + { + name: "template with no default containers and base containers", + templatePodSpec: &v1.PodSpec{}, + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1}, + InitContainers: []v1.Container{initContainer1}, + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1}, + InitContainers: []v1.Container{initContainer1}, + SchedulerName: "baseScheduler", + }, + }, + { + name: "template and base with matching containers", + templatePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "default-task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ { - Name: "nccl", - MountPath: "abc", + Name: "task-init-1", + Image: "default-task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", }, }, }, - v1.Container{ - Name: "bar-init", + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1}, + InitContainers: []v1.Container{initContainer1}, + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + }, + SchedulerName: "baseScheduler", }, }, - NodeSelector: map[string]string{ - "baz": "bar", + { + name: "template and base with no matching containers", + templatePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "not-matching", + Image: "default-task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ + { + Name: "not-matching-init", + Image: "default-task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + }, + }, + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1}, + InitContainers: []v1.Container{initContainer1}, + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1}, + InitContainers: []v1.Container{initContainer1}, + SchedulerName: "baseScheduler", + }, }, - Priority: &priority, - SchedulerName: "overrideScheduler", - Tolerations: []v1.Toleration{ - v1.Toleration{ - Key: "bar", + { + name: "template with default containers and base with containers", + templatePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "default", + Image: "default-task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ + { + Name: "default-init", + Image: "default-task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + }, }, - v1.Toleration{ - Key: "baz", + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1, baseContainer2}, + InitContainers: []v1.Container{initContainer1, initContainer2}, + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + { + Name: "task-2", + Image: "task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + { + Name: "task-init-2", + Image: "task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + }, + SchedulerName: "baseScheduler", }, }, + { + name: "template with primary containers and base with containers", + templatePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "primary", + Image: "default-task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + }, + InitContainers: []v1.Container{ + { + Name: "primary-init", + Image: "default-task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + }, + }, + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{baseContainer1, baseContainer2}, + InitContainers: []v1.Container{initContainer1, initContainer2}, + SchedulerName: "baseScheduler", + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + TerminationMessagePath: "/dev/template-termination-log", + }, + baseContainer2, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + TerminationMessagePath: "/dev/template-init-termination-log", + }, + initContainer2, + }, + SchedulerName: "baseScheduler", + }, + primaryContainerName: "task-1", + primaryInitContainerName: "task-init-1", + }, } - defaultContainerTemplate := v1.Container{ - Name: defaultContainerTemplateName, - TerminationMessagePath: "/dev/default-termination-log", - } - - primaryContainerTemplate := v1.Container{ - Name: primaryContainerTemplateName, - TerminationMessagePath: "/dev/primary-termination-log", - } - - defaultInitContainerTemplate := v1.Container{ - Name: defaultInitContainerTemplateName, - TerminationMessagePath: "/dev/default-init-termination-log", + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, mergeErr := MergeBasePodSpecOntoTemplate(tt.templatePodSpec, tt.basePodSpec, tt.primaryContainerName, tt.primaryInitContainerName) + assert.Equal(t, tt.expectedResult, result) + assert.Equal(t, tt.expectedError, mergeErr) + }) } +} - primaryInitContainerTemplate := v1.Container{ - Name: primaryInitContainerTemplateName, - TerminationMessagePath: "/dev/primary-init-termination-log", - } +func TestMergeOverlayPodSpecOntoBase(t *testing.T) { - podTemplateSpec := v1.PodSpec{ - Containers: []v1.Container{ - defaultContainerTemplate, - primaryContainerTemplate, + tests := []struct { + name string + basePodSpec *v1.PodSpec + overlayPodSpec *v1.PodSpec + expectedResult *v1.PodSpec + expectedError error + }{ + { + name: "nil overlay", + basePodSpec: &v1.PodSpec{}, + overlayPodSpec: nil, + expectedError: errors.New("neither the basePodSpec or the overlayPodSpec can be nil"), }, - InitContainers: []v1.Container{ - defaultInitContainerTemplate, - primaryInitContainerTemplate, + { + name: "nil base", + basePodSpec: nil, + overlayPodSpec: &v1.PodSpec{}, + expectedError: errors.New("neither the basePodSpec or the overlayPodSpec can be nil"), }, - HostNetwork: true, - NodeSelector: map[string]string{ - "foo": "bar", + { + name: "nil base and overlay", + basePodSpec: nil, + overlayPodSpec: nil, + expectedError: errors.New("neither the basePodSpec or the overlayPodSpec can be nil"), }, - SchedulerName: "defaultScheduler", - Tolerations: []v1.Toleration{ - v1.Toleration{ - Key: "foo", + { + name: "base and overlay no overlap", + basePodSpec: &v1.PodSpec{ + SchedulerName: "baseScheduler", + }, + overlayPodSpec: &v1.PodSpec{ + ServiceAccountName: "overlayServiceAccount", + }, + expectedResult: &v1.PodSpec{ + SchedulerName: "baseScheduler", + ServiceAccountName: "overlayServiceAccount", + }, + }, + { + name: "template and base with overlap", + basePodSpec: &v1.PodSpec{ + SchedulerName: "baseScheduler", + }, + overlayPodSpec: &v1.PodSpec{ + SchedulerName: "overlayScheduler", + ServiceAccountName: "overlayServiceAccount", + }, + expectedResult: &v1.PodSpec{ + SchedulerName: "overlayScheduler", + ServiceAccountName: "overlayServiceAccount", + }, + }, + { + name: "template and base with matching containers", + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + }, + }, + }, + overlayPodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "overlay-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "overlay-init-image", + }, + }, + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "overlay-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "overlay-init-image", + }, + }, + }, + }, + { + name: "base and overlay with no matching containers", + basePodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + }, + }, + }, + overlayPodSpec: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "overlay-1", + Image: "overlay-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "overlay-init-1", + Image: "overlay-init-image", + }, + }, + }, + expectedResult: &v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "task-1", + Image: "task-image", + }, + }, + InitContainers: []v1.Container{ + { + Name: "task-init-1", + Image: "task-init-image", + }, + }, }, }, } - mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") - assert.Nil(t, err) - - // validate a PodTemplate-only field - assert.Equal(t, podTemplateSpec.HostNetwork, mergedPodSpec.HostNetwork) - // validate a PodSpec-only field - assert.Equal(t, podSpec.Priority, mergedPodSpec.Priority) - // validate an overwritten PodTemplate field - assert.Equal(t, podSpec.SchedulerName, mergedPodSpec.SchedulerName) - // validate a merged map - assert.Equal(t, len(podTemplateSpec.NodeSelector)+len(podSpec.NodeSelector), len(mergedPodSpec.NodeSelector)) - // validate an appended array - assert.Equal(t, len(podTemplateSpec.Tolerations)+len(podSpec.Tolerations), len(mergedPodSpec.Tolerations)) - - // validate primary container - primaryContainer := mergedPodSpec.Containers[0] - assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) - assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) - assert.Equal(t, 1, len(primaryContainer.VolumeMounts)) - - // validate default container - defaultContainer := mergedPodSpec.Containers[1] - assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) - assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) - - // validate primary init container - primaryInitContainer := mergedPodSpec.InitContainers[0] - assert.Equal(t, podSpec.InitContainers[0].Name, primaryInitContainer.Name) - assert.Equal(t, primaryInitContainerTemplate.TerminationMessagePath, primaryInitContainer.TerminationMessagePath) - assert.Equal(t, 1, len(primaryInitContainer.VolumeMounts)) - - // validate default init container - defaultInitContainer := mergedPodSpec.InitContainers[1] - assert.Equal(t, podSpec.InitContainers[1].Name, defaultInitContainer.Name) - assert.Equal(t, defaultInitContainerTemplate.TerminationMessagePath, defaultInitContainer.TerminationMessagePath) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, mergeErr := MergeOverlayPodSpecOntoBase(tt.basePodSpec, tt.overlayPodSpec) + assert.Equal(t, tt.expectedResult, result) + assert.Equal(t, tt.expectedError, mergeErr) + }) + } } func TestAddFlyteCustomizationsToContainer_SetConsoleUrl(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index db5c8788de..defd7c1e85 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -370,7 +370,7 @@ func buildHeadPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodSpe // Inject a sidecar for capturing and exposing Ray job logs injectLogsSidecar(primaryContainer, basePodSpec) - basePodSpec, err := mergeCustomPodSpec(primaryContainer, basePodSpec, spec.GetK8SPod()) + basePodSpec, err := mergeCustomPodSpec(basePodSpec, spec.GetK8SPod()) if err != nil { return v1.PodTemplateSpec{}, err } @@ -500,7 +500,7 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodS } primaryContainer.Ports = append(primaryContainer.Ports, ports...) - basePodSpec, err := mergeCustomPodSpec(primaryContainer, basePodSpec, spec.GetK8SPod()) + basePodSpec, err := mergeCustomPodSpec(basePodSpec, spec.GetK8SPod()) if err != nil { return v1.PodTemplateSpec{}, err } @@ -518,7 +518,7 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodS } // Merges a ray head/worker node custom pod specs onto task's generated pod spec -func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8sPod *core.K8SPod) (*v1.PodSpec, error) { +func mergeCustomPodSpec(podSpec *v1.PodSpec, k8sPod *core.K8SPod) (*v1.PodSpec, error) { if k8sPod == nil { return podSpec, nil } @@ -535,13 +535,7 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.GetPodSpec(), err.Error()) } - err = utils.UnmarshalStructToObj(k8sPod.GetPodSpec(), &customPodSpec) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, - "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.GetPodSpec(), err.Error()) - } - - podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainer.Name, "") + podSpec, err = flytek8s.MergeOverlayPodSpecOntoBase(podSpec, customPodSpec) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 14c5ef9c8d..56eb746f15 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -209,7 +209,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont "Unable to unmarshal driver pod spec [%v], Err: [%v]", driverPod.GetPodSpec(), err.Error()) } - podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") + podSpec, err = flytek8s.MergeOverlayPodSpecOntoBase(podSpec, customPodSpec) if err != nil { return nil, err } @@ -262,7 +262,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo "Unable to unmarshal executor pod spec [%v], Err: [%v]", executorPod.GetPodSpec(), err.Error()) } - podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") + podSpec, err = flytek8s.MergeOverlayPodSpecOntoBase(podSpec, customPodSpec) if err != nil { return nil, err }