From a57868581c4edba7defb5fc0aa005f8caec47557 Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Mon, 24 Jun 2024 20:26:03 -0400 Subject: [PATCH] Introduced error constants and replaced reflect with cmp --- .../experiment/manifest/generator.go | 27 +- .../experiment/manifest/generator_test.go | 141 +++---- .../trial/util/job_util_test.go | 59 ++- .../v1beta1/goptuna/converter_test.go | 86 ++-- pkg/webhook/v1beta1/pod/inject_webhook.go | 16 +- .../v1beta1/pod/inject_webhook_test.go | 385 +++++++++--------- pkg/webhook/v1beta1/pod/utils.go | 6 +- 7 files changed, 341 insertions(+), 379 deletions(-) diff --git a/pkg/controller.v1beta1/experiment/manifest/generator.go b/pkg/controller.v1beta1/experiment/manifest/generator.go index a41cc2f0c89..c3d0a2a14bc 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator.go @@ -17,6 +17,7 @@ limitations under the License. package manifest import ( + "errors" "fmt" "regexp" "strings" @@ -33,6 +34,15 @@ import ( "github.com/kubeflow/katib/pkg/util/v1beta1/katibconfig" ) +var ( + errConfigMapNotFound = errors.New("configMap not found") + errConvertStringToUnstructuredFailed = errors.New("failed to convert string to unstructured") + errConvertUnstructuredToStringFailed = errors.New("failed to convert unstructured to string") + errParamNotFoundInParameterAssignment = errors.New("unable to find non-meta parameter from TrialParameters in ParameterAssignment") + errParamNotFoundInTrialParameters = errors.New("unable to find parameter from ParameterAssignment in TrialParameters") + errTrialTemplateNotFound = errors.New("unable to find trial template in ConfigMap") +) + // Generator is the type for manifests Generator. type Generator interface { InjectClient(c client.Client) @@ -86,7 +96,7 @@ func (g *DefaultGenerator) GetRunSpecWithHyperParameters(experiment *experiments // Convert Trial template to unstructured runSpec, err := util.ConvertStringToUnstructured(replacedTemplate) if err != nil { - return nil, fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return nil, fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err) } // Set name and namespace for Run Spec @@ -108,7 +118,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi if trialSpec == nil { trialSpec, err = util.ConvertStringToUnstructured(trialTemplate) if err != nil { - return "", fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return "", fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err) } } @@ -131,7 +141,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi nonMetaParamCount += 1 continue } else { - return "", fmt.Errorf("Unable to find parameter: %v in parameter assignment %v", param.Reference, assignmentsMap) + return "", fmt.Errorf("%w: parameter: %v, parameter assignment: %v", errParamNotFoundInParameterAssignment, param.Reference, assignmentsMap) } } metaRefKey = sub[1] @@ -172,9 +182,10 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi } } - // Number of parameters must be equal + // Number of assignment parameters must be equal to the number of non-meta trial parameters + // i.e. all parameters in ParameterAssignment must be in TrialParameters if len(assignments) != nonMetaParamCount { - return "", fmt.Errorf("Number of TrialAssignment: %v != number of nonMetaTrialParameters in TrialSpec: %v", len(assignments), nonMetaParamCount) + return "", fmt.Errorf("%w: parameter assignments: %v, non-meta trial parameter count: %v", errParamNotFoundInTrialParameters, assignments, nonMetaParamCount) } // Replacing placeholders with parameter values @@ -194,7 +205,7 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim if trialSource.TrialSpec != nil { trialTemplateString, err = util.ConvertUnstructuredToString(trialSource.TrialSpec) if err != nil { - return "", fmt.Errorf("ConvertUnstructuredToString failed: %v", err) + return "", fmt.Errorf("%w: %w", errConvertUnstructuredToStringFailed, err) } } else { configMapNS := trialSource.ConfigMap.ConfigMapNamespace @@ -202,12 +213,12 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim templatePath := trialSource.ConfigMap.TemplatePath configMap, err := g.client.GetConfigMap(configMapName, configMapNS) if err != nil { - return "", fmt.Errorf("GetConfigMap failed: %v", err) + return "", fmt.Errorf("%w: %w", errConfigMapNotFound, err) } var ok bool trialTemplateString, ok = configMap[templatePath] if !ok { - return "", fmt.Errorf("TemplatePath: %v not found in configMap: %v", templatePath, configMap) + return "", fmt.Errorf("%w: TemplatePath: %v, ConfigMap: %v", errTrialTemplateNotFound, templatePath, configMap) } } diff --git a/pkg/controller.v1beta1/experiment/manifest/generator_test.go b/pkg/controller.v1beta1/experiment/manifest/generator_test.go index dabd2631063..6059d4d7d96 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator_test.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator_test.go @@ -17,11 +17,11 @@ limitations under the License. package manifest import ( - "errors" "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "go.uber.org/mock/gomock" batchv1 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" @@ -88,23 +88,18 @@ func TestGetRunSpecWithHP(t *testing.T) { t.Errorf("ConvertObjectToUnstructured failed: %v", err) } - tcs := []struct { - instance *experimentsv1beta1.Experiment - parameterAssignments []commonapiv1beta1.ParameterAssignment - expectedRunSpec *unstructured.Unstructured - err bool - testDescription string + cases := map[string]struct { + instance *experimentsv1beta1.Experiment + parameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run - { - instance: newFakeInstance(), - parameterAssignments: newFakeParameterAssignment(), - expectedRunSpec: expectedRunSpec, - err: false, - testDescription: "Run with valid parameters", + "Run with valid parameters": { + instance: newFakeInstance(), + parameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid JSON in unstructured - { + "Invalid JSON in Unstructured Trial template": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() trialSpec := i.Spec.TrialTemplate.TrialSource.TrialSpec @@ -114,48 +109,45 @@ func TestGetRunSpecWithHP(t *testing.T) { return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid JSON in Trial template", + wantError: errConvertUnstructuredToStringFailed, }, - // len(parameterAssignment) != len(trialParameters) - { + "Non-meta parameter from TrialParameters not found in ParameterAssignment": { instance: newFakeInstance(), parameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa = pa[1:] + pa[0] = commonapiv1beta1.ParameterAssignment{ + Name: "invalid-name", + Value: "invalid-value", + } return pa }(), - err: true, - testDescription: "Number of parameter assignments is not equal to number of Trial parameters", + wantError: errParamNotFoundInParameterAssignment, }, - // Parameter from assignments not found in Trial parameters - { + // case in which the lengths of trial parameters and parameter assignments are different + "Parameter from ParameterAssignment not found in TrialParameters": { instance: newFakeInstance(), parameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa[0] = commonapiv1beta1.ParameterAssignment{ - Name: "invalid-name", - Value: "invalid-value", - } + pa = append(pa, commonapiv1beta1.ParameterAssignment{ + Name: "extra-name", + Value: "extra-value", + }) return pa }(), - err: true, - testDescription: "Trial parameters don't have parameter from assignments", + wantError: errParamNotFoundInTrialParameters, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) - - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + }) } } @@ -208,7 +200,7 @@ spec: map[string]string{templatePath: trialSpec}, nil) invalidConfigMapName := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - nil, errors.New("Unable to get ConfigMap")) + nil, errConfigMapNotFound) validGetConfigMap3 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( map[string]string{templatePath: trialSpec}, nil) @@ -244,19 +236,18 @@ spec: - "--momentum=0.9"` expectedRunSpec, err := util.ConvertStringToUnstructured(expectedStr) - if err != nil { - t.Errorf("ConvertStringToUnstructured failed: %v", err) + if diff := cmp.Diff(nil, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("ConvertStringToUnstructured failed (-want,+got):\n%s", diff) } - tcs := []struct { - instance *experimentsv1beta1.Experiment - parameterAssignments []commonapiv1beta1.ParameterAssignment - err bool - testDescription string + cases := map[string]struct { + instance *experimentsv1beta1.Experiment + parameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run // validGetConfigMap1 case - { + "Run with valid parameters": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -268,13 +259,11 @@ spec: } return i }(), - parameterAssignments: newFakeParameterAssignment(), - err: false, - testDescription: "Run with valid parameters", + parameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid ConfigMap name // invalidConfigMapName case - { + "Invalid ConfigMap name": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -285,12 +274,10 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid ConfigMap name", + wantError: errConfigMapNotFound, }, - // Invalid template path in ConfigMap name // validGetConfigMap3 case - { + "Invalid template path in ConfigMap name": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -303,14 +290,12 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid template path in ConfigMap", + wantError: errTrialTemplateNotFound, }, - // Invalid Trial template spec in ConfigMap + // invalidTemplate case // Trial template is a string in ConfigMap // Because of that, user can specify not valid unstructured template - // invalidTemplate case - { + "Invalid trial spec in ConfigMap": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -323,22 +308,20 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid Trial spec in ConfigMap", + wantError: errConvertStringToUnstructuredFailed, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + }) } } diff --git a/pkg/controller.v1beta1/trial/util/job_util_test.go b/pkg/controller.v1beta1/trial/util/job_util_test.go index d2a018967c3..c1908144df9 100644 --- a/pkg/controller.v1beta1/trial/util/job_util_test.go +++ b/pkg/controller.v1beta1/trial/util/job_util_test.go @@ -17,7 +17,6 @@ limitations under the License. package util import ( - "reflect" "testing" batchv1 "k8s.io/api/batch/v1" @@ -25,6 +24,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" trialsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/trials/v1beta1" "github.com/kubeflow/katib/pkg/controller.v1beta1/util" ) @@ -39,14 +40,13 @@ func TestGetDeployedJobStatus(t *testing.T) { successCondition := "status.conditions.#(type==\"Complete\")#|#(status==\"True\")#" failureCondition := "status.conditions.#(type==\"Failed\")#|#(status==\"True\")#" - tcs := []struct { - trial *trialsv1beta1.Trial - deployedJob *unstructured.Unstructured - expectedTrialJobStatus *TrialJobStatus - err bool - testDescription string + cases := map[string]struct { + trial *trialsv1beta1.Trial + deployedJob *unstructured.Unstructured + wantTrialJobStatus *TrialJobStatus + wantError error }{ - { + "Job status is running": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -54,28 +54,24 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobRunning, } }(), - err: false, - testDescription: "Job status is running", }, - { + "Job status is succeeded, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is succeeded, reason and message must be returned", }, - { + "Job status is failed, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -83,41 +79,35 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobFailed, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is failed, reason and message must be returned", }, - { + "Job status is succeeded because status.succeeded = 1": { trial: newFakeTrial("status.[@this].#(succeeded==1)", failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, } }(), - err: false, - testDescription: "Job status is succeeded because status.succeeded = 1", }, } - for _, tc := range tcs { - actualTrialJobStatus, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) - - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedTrialJobStatus, actualTrialJobStatus) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedTrialJobStatus, actualTrialJobStatus) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetDeployedJobStatus() (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantTrialJobStatus, got); len(diff) != 0 { + t.Errorf("Unexpected trial job status from GetDeployedJobStatus() (-want,+got):\n%s", diff) } - } + }) } } @@ -154,6 +144,7 @@ func newFakeJob() *batchv1.Job { }, } } + func newFakeDeployedJob(job interface{}) *unstructured.Unstructured { jobUnstructured, _ := util.ConvertObjectToUnstructured(job) diff --git a/pkg/suggestion/v1beta1/goptuna/converter_test.go b/pkg/suggestion/v1beta1/goptuna/converter_test.go index 1e3189a773d..d92b0f391ca 100644 --- a/pkg/suggestion/v1beta1/goptuna/converter_test.go +++ b/pkg/suggestion/v1beta1/goptuna/converter_test.go @@ -17,48 +17,44 @@ limitations under the License. package suggestion_goptuna_v1beta1 import ( - "reflect" "testing" "github.com/c-bata/goptuna" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" api_v1_beta1 "github.com/kubeflow/katib/pkg/apis/manager/v1beta1" ) func Test_toGoptunaDirection(t *testing.T) { - for _, tt := range []struct { - name string + for name, tc := range map[string]struct { objectiveType api_v1_beta1.ObjectiveType - expected goptuna.StudyDirection + wantDirection goptuna.StudyDirection }{ - { - name: "minimize", + "minimize": { objectiveType: api_v1_beta1.ObjectiveType_MINIMIZE, - expected: goptuna.StudyDirectionMinimize, + wantDirection: goptuna.StudyDirectionMinimize, }, - { - name: "maximize", + "maximize": { objectiveType: api_v1_beta1.ObjectiveType_MAXIMIZE, - expected: goptuna.StudyDirectionMaximize, + wantDirection: goptuna.StudyDirectionMaximize, }, } { - t.Run(tt.name, func(t *testing.T) { - got := toGoptunaDirection(tt.objectiveType) - if got != tt.expected { - t.Errorf("toGoptunaDirection() got = %v, want %v", got, tt.expected) + t.Run(name, func(t *testing.T) { + got := toGoptunaDirection(tc.objectiveType) + if diff := cmp.Diff(tc.wantDirection, got); len(diff) != 0 { + t.Errorf("Unexpected direction from toGoptunaDirection (-want,+got):\n%s", diff) } }) } } func Test_toGoptunaSearchSpace(t *testing.T) { - tests := []struct { - name string - parameters []*api_v1_beta1.ParameterSpec - want map[string]interface{} - wantErr bool + cases := map[string]struct { + parameters []*api_v1_beta1.ParameterSpec + wantSearchSpace map[string]interface{} + wantError error }{ - { - name: "Double parameter type", + "Double parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -69,16 +65,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.UniformDistribution{ High: 5.5, Low: 1.5, }, }, - wantErr: false, }, - { - name: "Double parameter type with step", + "Double parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -90,17 +84,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.DiscreteUniformDistribution{ High: 5.5, Low: 1.5, Q: 0.5, }, }, - wantErr: false, }, - { - name: "Int parameter type", + "Int parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -111,16 +103,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.IntUniformDistribution{ High: 5, Low: 1, }, }, - wantErr: false, }, - { - name: "Int parameter type with step", + "Int parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -132,17 +122,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.StepIntUniformDistribution{ High: 5, Low: 1, Step: 2, }, }, - wantErr: false, }, - { - name: "Discrete parameter type", + "Discrete parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-discrete", @@ -152,15 +140,13 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-discrete": goptuna.CategoricalDistribution{ Choices: []string{"3", "2", "6"}, }, }, - wantErr: false, }, - { - name: "Categorical parameter type", + "Categorical parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-categorical", @@ -170,23 +156,21 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-categorical": goptuna.CategoricalDistribution{ Choices: []string{"cat1", "cat2", "cat3"}, }, }, - wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := toGoptunaSearchSpace(tt.parameters) - if (err != nil) != tt.wantErr { - t.Errorf("toGoptunaSearchSpace() error = %v, wantErr %v", err, tt.wantErr) - return + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := toGoptunaSearchSpace(tc.parameters) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from toGoptunaSearchSpace (-want,+got):\n%s", diff) } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("toGoptunaSearchSpace() got = %v, want %v", got, tt.want) + if diff := cmp.Diff(tc.wantSearchSpace, got); len(diff) != 0 { + t.Errorf("Unexpected search space from toGoptunaSearchSpace (-want,+got):\n%s", diff) } }) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go index f48b4b96a07..f5270528848 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "path/filepath" "strconv" @@ -47,6 +48,13 @@ import ( var log = logf.Log.WithName("injector-webhook") +var ( + errInvalidOwnerAPIVersion = errors.New("invalid owner API version") + errInvalidSuggestionName = errors.New("invalid suggestion name") + errPodNotBelongToKatibJob = errors.New("pod does not belong to Katib Job") + errFailedToGetTrialTemplateJob = errors.New("unable to get Job in the trialTemplate") +) + // SidecarInjector injects metrics collect sidecar to the primary pod. type SidecarInjector struct { client client.Client @@ -259,7 +267,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Get group and version from owner API version gv, err := schema.ParseGroupVersion(owners[i].APIVersion) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %w", errInvalidOwnerAPIVersion, err) } gvk := schema.GroupVersionKind{ Group: gv.Group, @@ -272,7 +280,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Nested object namespace must be equal to object namespace err = s.client.Get(context.TODO(), apitypes.NamespacedName{Name: owners[i].Name, Namespace: namespace}, nestedJob) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %w", errFailedToGetTrialTemplateJob, err) } // Recursively search for Trial ownership in nested object jobKind, jobName, err = s.getKatibJob(nestedJob, namespace) @@ -285,7 +293,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // If jobKind is empty after the loop, Trial doesn't own the object if jobKind == "" { - return "", "", errors.New("The Pod doesn't belong to Katib Job") + return "", "", errPodNotBelongToKatibJob } return jobKind, jobName, nil @@ -322,7 +330,7 @@ func (s *SidecarInjector) getMetricsCollectorArgs(trial *trialsv1beta1.Trial, me suggestion := &suggestionsv1beta1.Suggestion{} err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: suggestionName, Namespace: trial.Namespace}, suggestion) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", errInvalidSuggestionName, err) } args = append(args, "-s-earlystop", util.GetEarlyStoppingEndpoint(suggestion)) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go index ab1646b6769..5120b276990 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go @@ -20,17 +20,17 @@ import ( "context" "fmt" "path/filepath" - "reflect" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" @@ -76,16 +76,15 @@ func TestWrapWorkerContainer(t *testing.T) { metricsFile := "metric.log" - testCases := []struct { - trial *trialsv1beta1.Trial - pod *v1.Pod - metricsFile string - pathKind common.FileSystemKind - expectedPod *v1.Pod - err bool - testDescription string + cases := map[string]struct { + trial *trialsv1beta1.Trial + pod *v1.Pod + metricsFile string + pathKind common.FileSystemKind + wantPod *v1.Pod + wantError error }{ - { + "Tensorflow container without sh -c": { trial: trial, pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -101,7 +100,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -116,10 +115,8 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, - testDescription: "Tensorflow container without sh -c", }, - { + "Tensorflow container with sh -c": { trial: trial, pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -136,7 +133,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -151,10 +148,8 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, - testDescription: "Tensorflow container with sh -c", }, - { + "Training pod doesn't have primary container": { trial: trial, pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -165,11 +160,19 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - pathKind: common.FileKind, - err: true, - testDescription: "Training pod doesn't have primary container", + pathKind: common.FileKind, + wantPod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "not-primary-container", + }, + }, + }, + }, + wantError: errPrimaryContainerNotFound, }, - { + "Container with early stopping command": { trial: func() *trialsv1beta1.Trial { t := trial.DeepCopy() t.Spec.EarlyStoppingRules = []common.EarlyStoppingRule{ @@ -195,7 +198,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -214,23 +217,19 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, - testDescription: "Container with early stopping command", }, } - for _, c := range testCases { - err := wrapWorkerContainer(c.trial, c.pod, c.trial.Namespace, c.metricsFile, c.pathKind) - if c.err && err == nil { - t.Errorf("Case %s failed. Expected error, got nil", c.testDescription) - } else if !c.err { - if err != nil { - t.Errorf("Case %s failed. Expected nil, got error: %v", c.testDescription, err) - } else if !equality.Semantic.DeepEqual(c.pod.Spec.Containers, c.expectedPod.Spec.Containers) { - t.Errorf("Case %s failed. Expected pod: %v, got: %v", - c.testDescription, c.expectedPod.Spec.Containers, c.pod.Spec.Containers) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + err := wrapWorkerContainer(tc.trial, tc.pod, tc.trial.Namespace, tc.metricsFile, tc.pathKind) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from wrapWorkerContainer (-want,+got):\n%s", diff) } - } + if diff := cmp.Diff(tc.wantPod.Spec.Containers, tc.pod.Spec.Containers); len(diff) != 0 { + t.Errorf("Unexpected pod from wrapWorkerContainer (-want,+got):\n%s", diff) + } + }) } } @@ -318,17 +317,16 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, } - testCases := []struct { + cases := map[string]struct { trial *trialsv1beta1.Trial metricNames string mCSpec common.MetricsCollectorSpec earlyStoppingRules []string katibConfig configv1beta1.MetricsCollectorConfig - expectedArgs []string - name string - err bool + wantArgs []string + wantError error }{ - { + "StdOut MC": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -339,7 +337,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { katibConfig: configv1beta1.MetricsCollectorConfig{ WaitAllProcesses: &waitAllProcessesValue, }, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -348,9 +346,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-format", string(common.TextFormat), "-w", "false", }, - name: "StdOut MC", }, - { + "File MC with Filter": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -371,7 +368,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -380,9 +377,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-f", "{mn1: ([a-b]), mv1: [0-9]};{mn2: ([a-b]), mv2: ([0-9])}", "-format", string(common.TextFormat), }, - name: "File MC with Filter", }, - { + "File MC with Json Format": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -397,7 +393,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -405,9 +401,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-path", testPath, "-format", string(common.JsonFormat), }, - name: "File MC with Json Format", }, - { + "Tf Event MC": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -421,16 +416,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, "-path", testPath, }, - name: "Tf Event MC", }, - { + "Custom MC without Path": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -439,15 +433,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, }, - name: "Custom MC without Path", }, - { + "Custom MC with Path": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -461,16 +454,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, "-path", testPath, }, - name: "Custom MC with Path", }, - { + "Prometheus MC without Path": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -479,15 +471,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, }, - name: "Prometheus MC without Path", }, - { + "Trial with EarlyStopping rules": { trial: testTrial, metricNames: testMetricName, mCSpec: common.MetricsCollectorSpec{ @@ -497,7 +488,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, earlyStoppingRules: earlyStoppingRules, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -508,9 +499,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-stop-rule", earlyStoppingRules[1], "-s-earlystop", katibEarlyStopAddress, }, - name: "Trial with EarlyStopping rules", }, - { + "Trial with invalid Experiment label name. Suggestion is not created": { trial: func() *trialsv1beta1.Trial { trial := testTrial.DeepCopy() trial.ObjectMeta.Labels[consts.LabelExperimentName] = "invalid-name" @@ -523,8 +513,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, earlyStoppingRules: earlyStoppingRules, katibConfig: configv1beta1.MetricsCollectorConfig{}, - name: "Trial with invalid Experiment label name. Suggestion is not created", - err: true, + wantError: errInvalidSuggestionName, }, } @@ -535,16 +524,16 @@ func TestGetMetricsCollectorArgs(t *testing.T) { return c.Get(context.TODO(), types.NamespacedName{Namespace: testNamespace, Name: testSuggestionName}, testSuggestion) }, timeout).ShouldNot(gomega.HaveOccurred()) - for _, tc := range testCases { - args, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules) - - if !tc.err && err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.name, err) - } else if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.name) - } else if !tc.err && !reflect.DeepEqual(tc.expectedArgs, args) { - t.Errorf("Case %v failed. ExpectedArgs: %v, got %v", tc.name, tc.expectedArgs, args) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getMetricsCollectorArgs (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantArgs, got); len(diff) != 0 { + t.Errorf("Unexpected args from getMetricsCollectorArgs (-want,+got):\n%s", diff) + } + }) } } @@ -582,13 +571,13 @@ func TestNeedWrapWorkerContainer(t *testing.T) { func TestMutateMetricsCollectorVolume(t *testing.T) { tc := struct { pod v1.Pod - expectedPod v1.Pod - JobKind string - MountPath string - SidecarContainerName string - PrimaryContainerName string + wantPod v1.Pod + jobKind string + mountPath string + sidecarContainerName string + primaryContainerName string pathKind common.FileSystemKind - err bool + wantError error }{ pod: v1.Pod{ Spec: v1.PodSpec{ @@ -605,7 +594,7 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { }, }, }, - expectedPod: v1.Pod{ + wantPod: v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -640,45 +629,48 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { }, }, }, - MountPath: common.DefaultFilePath, - SidecarContainerName: "metrics-collector", - PrimaryContainerName: "train-job", + mountPath: common.DefaultFilePath, + sidecarContainerName: "metrics-collector", + primaryContainerName: "train-job", pathKind: common.FileKind, } err := mutateMetricsCollectorVolume( &tc.pod, - tc.MountPath, - tc.SidecarContainerName, - tc.PrimaryContainerName, + tc.mountPath, + tc.sidecarContainerName, + tc.primaryContainerName, tc.pathKind) - if err != nil { - t.Errorf("mutateMetricsCollectorVolume failed: %v", err) - } else if !equality.Semantic.DeepEqual(tc.pod, tc.expectedPod) { - t.Errorf("Expected pod %v, got %v", tc.expectedPod, tc.pod) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from mutateMetricsCollectorVolume (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantPod, tc.pod); len(diff) != 0 { + t.Errorf("Unexpected pod from mutateMetricsCollectorVolume (-want,+got):\n%s", diff) } } func TestGetSidecarContainerName(t *testing.T) { - testCases := []struct { - collectorKind common.CollectorKind - expectedCollectorKind string + cases := map[string]struct { + collectorKind common.CollectorKind + wantCollectorKind string }{ - { - collectorKind: common.StdOutCollector, - expectedCollectorKind: mccommon.MetricLoggerCollectorContainerName, + "Valid run with StdOutCollector": { + collectorKind: common.StdOutCollector, + wantCollectorKind: mccommon.MetricLoggerCollectorContainerName, }, - { - collectorKind: common.TfEventCollector, - expectedCollectorKind: mccommon.MetricCollectorContainerName, + "Valid run with TfEventCollector": { + collectorKind: common.TfEventCollector, + wantCollectorKind: mccommon.MetricCollectorContainerName, }, } - for _, tc := range testCases { - collectorKind := getSidecarContainerName(tc.collectorKind) - if collectorKind != tc.expectedCollectorKind { - t.Errorf("Expected Collector Kind: %v, got %v", tc.expectedCollectorKind, collectorKind) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + collectorKind := getSidecarContainerName(tc.collectorKind) + if collectorKind != tc.wantCollectorKind { + t.Errorf("Expected Collector Kind: %v, got %v", tc.wantCollectorKind, collectorKind) + } + }) } } @@ -720,16 +712,15 @@ func TestGetKatibJob(t *testing.T) { deployName := "deploy-name" jobName := "job-name" - testCases := []struct { - pod *v1.Pod - job *batchv1.Job - deployment *appsv1.Deployment - expectedJobKind string - expectedJobName string - err bool - testDescription string + cases := map[string]struct { + pod *v1.Pod + job *batchv1.Job + deployment *appsv1.Deployment + wantJobKind string + wantJobName string + wantError error }{ - { + "Valid run with ownership sequence: Trial -> Job -> Pod": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -770,12 +761,10 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - expectedJobKind: "Job", - expectedJobName: jobName + "-1", - err: false, - testDescription: "Valid run with ownership sequence: Trial -> Job -> Pod", + wantJobKind: "Job", + wantJobName: jobName + "-1", }, - { + "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -849,12 +838,10 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - expectedJobKind: "Deployment", - expectedJobName: deployName + "-2", - err: false, - testDescription: "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod", + wantJobKind: "Deployment", + wantJobName: deployName + "-2", }, - { + "Run for not Trial's pod with ownership sequence: Job -> Pod": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -887,10 +874,9 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, - testDescription: "Run for not Trial's pod with ownership sequence: Job -> Pod", + wantError: errPodNotBelongToKatibJob, }, - { + "Run when Pod owns Job that doesn't exists": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -904,10 +890,9 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, - testDescription: "Run when Pod owns Job that doesn't exists", + wantError: errFailedToGetTrialTemplateJob, }, - { + "Run when Pod owns Job with invalid API version": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -921,64 +906,63 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, - testDescription: "Run when Pod owns Job with invalid API version", + wantError: errInvalidOwnerAPIVersion, }, } - for _, tc := range testCases { - // Create Job if it is needed - if tc.job != nil { - jobUnstr, err := util.ConvertObjectToUnstructured(tc.job) - gvk := schema.GroupVersionKind{ - Group: "batch", - Version: "v1", - Kind: "Job", - } - jobUnstr.SetGroupVersionKind(gvk) - if err != nil { - t.Errorf("ConvertObjectToUnstructured error %v", err) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + // Create Job if it is needed + if tc.job != nil { + jobUnstr, err := util.ConvertObjectToUnstructured(tc.job) + gvk := schema.GroupVersionKind{ + Group: "batch", + Version: "v1", + Kind: "Job", + } + jobUnstr.SetGroupVersionKind(gvk) + if err != nil { + t.Errorf("ConvertObjectToUnstructured error %v", err) + } - g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred()) - // Wait that Job is created - g.Eventually(func() error { - return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.job.Name}, jobUnstr) - }, timeout).ShouldNot(gomega.HaveOccurred()) - } + // Wait that Job is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.job.Name}, jobUnstr) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } - // Create Deployment if it is needed - if tc.deployment != nil { - g.Expect(c.Create(context.TODO(), tc.deployment)).NotTo(gomega.HaveOccurred()) + // Create Deployment if it is needed + if tc.deployment != nil { + g.Expect(c.Create(context.TODO(), tc.deployment)).NotTo(gomega.HaveOccurred()) - // Wait that Deployment is created - g.Eventually(func() error { - return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.deployment.Name}, tc.deployment) - }, timeout).ShouldNot(gomega.HaveOccurred()) - } + // Wait that Deployment is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.deployment.Name}, tc.deployment) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } - object, _ := util.ConvertObjectToUnstructured(tc.pod) - jobKind, jobName, err := si.getKatibJob(object, namespace) - if !tc.err && err != nil { - t.Errorf("Case %v failed. Error %v", tc.testDescription, err) - } else if !tc.err && (tc.expectedJobKind != jobKind || tc.expectedJobName != jobName) { - t.Errorf("Case %v failed. Expected jobKind %v, got %v, Expected jobName %v, got %v", - tc.testDescription, tc.expectedJobKind, jobKind, tc.expectedJobName, jobName) - } else if tc.err && err == nil { - t.Errorf("Expected error got nil") - } + object, _ := util.ConvertObjectToUnstructured(tc.pod) + jobKind, jobName, err := si.getKatibJob(object, namespace) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getKatibJob (-want,+got):\n%s", diff) + } + if tc.wantError == nil && (tc.wantJobKind != jobKind || tc.wantJobName != jobName) { + t.Errorf("Unexpected error from getKatibJob, expected jobKind %v, got %v, expected jobName %v, got %v", + tc.wantJobKind, jobKind, tc.wantJobName, jobName) + } + }) } } func TestIsPrimaryPod(t *testing.T) { - testCases := []struct { + cases := map[string]struct { podLabels map[string]string primaryPodLabels map[string]string isPrimary bool - testDescription string }{ - { + "Pod contains all labels from primary pod labels": { podLabels: map[string]string{ "test-key-1": "test-value-1", "test-key-2": "test-value-2", @@ -988,10 +972,9 @@ func TestIsPrimaryPod(t *testing.T) { "test-key-1": "test-value-1", "test-key-2": "test-value-2", }, - isPrimary: true, - testDescription: "Pod contains all labels from primary pod labels", + isPrimary: true, }, - { + "Pod doesn't contain primary label": { podLabels: map[string]string{ "test-key-1": "test-value-1", }, @@ -999,26 +982,26 @@ func TestIsPrimaryPod(t *testing.T) { "test-key-1": "test-value-1", "test-key-2": "test-value-2", }, - isPrimary: false, - testDescription: "Pod doesn't contain primary label", + isPrimary: false, }, - { + "Pod contains label with incorrect value": { podLabels: map[string]string{ "test-key-1": "invalid", }, primaryPodLabels: map[string]string{ "test-key-1": "test-value-1", }, - isPrimary: false, - testDescription: "Pod contains label with incorrect value", + isPrimary: false, }, } - for _, tc := range testCases { - isPrimary := isPrimaryPod(tc.podLabels, tc.primaryPodLabels) - if isPrimary != tc.isPrimary { - t.Errorf("Case %v. Expected isPrimary %v, got %v", tc.testDescription, tc.isPrimary, isPrimary) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + isPrimary := isPrimaryPod(tc.podLabels, tc.primaryPodLabels) + if isPrimary != tc.isPrimary { + t.Errorf("Case %v. Expected isPrimary %v, got %v", name, tc.isPrimary, isPrimary) + } + }) } } @@ -1029,13 +1012,12 @@ func TestMutatePodMetadata(t *testing.T) { consts.LabelTrialName: "test-trial", } - testCases := []struct { - pod *v1.Pod - trial *trialsv1beta1.Trial - mutatedPod *v1.Pod - testDescription string + cases := map[string]struct { + pod *v1.Pod + trial *trialsv1beta1.Trial + mutatedPod *v1.Pod }{ - { + "Mutated Pod should contain label from the origin Pod and Trial": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Labels: map[string]string{ @@ -1056,14 +1038,15 @@ func TestMutatePodMetadata(t *testing.T) { Labels: mutatedPodLabels, }, }, - testDescription: "Mutated Pod should contain label from the origin Pod and Trial", }, } - for _, tc := range testCases { - mutatePodMetadata(tc.pod, tc.trial) - if !reflect.DeepEqual(tc.mutatedPod, tc.pod) { - t.Errorf("Case %v. Expected Pod %v, got %v", tc.testDescription, tc.mutatedPod, tc.pod) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + mutatePodMetadata(tc.pod, tc.trial) + if diff := cmp.Diff(tc.mutatedPod, tc.pod); len(diff) != 0 { + t.Errorf("Unexpected pod from mutatePodMetadata (-want,+got):\n%s", diff) + } + }) } } diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go index 6381bf6d895..6a25a2441e7 100644 --- a/pkg/webhook/v1beta1/pod/utils.go +++ b/pkg/webhook/v1beta1/pod/utils.go @@ -18,6 +18,7 @@ package pod import ( "context" + "errors" "fmt" "path/filepath" "strings" @@ -35,6 +36,8 @@ import ( mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" ) +var errPrimaryContainerNotFound = errors.New("unable to find primary container in mutated pod containers") + func isPrimaryPod(podLabels, primaryLabels map[string]string) bool { for primaryKey, primaryValue := range primaryLabels { @@ -190,8 +193,7 @@ func wrapWorkerContainer(trial *trialsv1beta1.Trial, pod *v1.Pod, namespace, c.Command = command c.Args = []string{argsStr} } else { - return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v", - trial.Spec.PrimaryContainerName, pod.Spec.Containers) + return fmt.Errorf("%w: primary container: %v, mutated pod containers: %v", errPrimaryContainerNotFound, trial.Spec.PrimaryContainerName, pod.Spec.Containers) } return nil }