diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index dd93ef8933..00f8778a66 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -2221,6 +2221,117 @@ func TestMergePodSpecs(t *testing.T) { assert.Equal(t, defaultInitContainerTemplate.TerminationMessagePath, defaultInitContainer.TerminationMessagePath) } +func TestMergePodSpecsPrimaryContainerName(t *testing.T) { + + basePodSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "primary", + VolumeMounts: []v1.VolumeMount{ + { + Name: "nccl", + MountPath: "abc", + }, + }, + Env: []v1.EnvVar{ + {Name: "EnvVar", Value: "EnvVal"}, + }, + }, + v1.Container{ + Name: "bar", + }, + }, + InitContainers: []v1.Container{ + v1.Container{ + Name: "primary-init", + VolumeMounts: []v1.VolumeMount{ + { + Name: "nccl", + MountPath: "abc", + }, + }, + }, + v1.Container{ + Name: "bar-init", + }, + }, + Tolerations: []v1.Toleration{ + v1.Toleration{ + Key: "bar", + }, + v1.Toleration{ + Key: "baz", + }, + }, + } + + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "primary-new", + VolumeMounts: []v1.VolumeMount{ + { + Name: "nccl-new", + MountPath: "abc-new", + }, + }, + Env: []v1.EnvVar{ + {Name: "EnvVar", Value: "EnvVal"}, + }, + }, + v1.Container{ + Name: "bar-new", + }, + }, + InitContainers: []v1.Container{ + v1.Container{ + Name: "primary-init-new", + VolumeMounts: []v1.VolumeMount{ + { + Name: "nccl-new", + MountPath: "abc-new", + }, + }, + }, + v1.Container{ + Name: "bar-init-new", + }, + }, + Tolerations: []v1.Toleration{ + v1.Toleration{ + Key: "bar-new", + }, + v1.Toleration{ + Key: "baz-new", + }, + }, + } + + // primary (init) container name different from basePodSpec + mergedPodSpec, err := MergePodSpecs(&basePodSpec, &podSpec, "primary-new", "primary-init-new") + assert.Nil(t, err) + + // validate an appended array + assert.Equal(t, len(basePodSpec.Tolerations)+len(podSpec.Tolerations), len(mergedPodSpec.Tolerations)) + + // validate primary container, should exclude the one in basePodSpec + primaryContainer := mergedPodSpec.Containers[0] + assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) + // this will also contain the one from basePodSpec + assert.Equal(t, 2, len(primaryContainer.VolumeMounts)) + assert.Equal(t, 2, len(primaryContainer.Env)) + + // WARN: the other container in podSpec is also included + assert.Equal(t, 2, len(mergedPodSpec.Containers)) + assert.Equal(t, mergedPodSpec.Containers[1].Name, podSpec.Containers[1].Name) + + // validate primary init container + primaryInitContainer := mergedPodSpec.InitContainers[0] + assert.Equal(t, podSpec.InitContainers[0].Name, primaryInitContainer.Name) + // this will also contain the one from basePodSpec + assert.Equal(t, 2, len(primaryInitContainer.VolumeMounts)) +} + func TestAddFlyteCustomizationsToContainer_SetConsoleUrl(t *testing.T) { tests := []struct { name string