diff --git a/api/lmes/v1alpha1/lmevaljob_types.go b/api/lmes/v1alpha1/lmevaljob_types.go index 26b782d7..a7138a0c 100644 --- a/api/lmes/v1alpha1/lmevaljob_types.go +++ b/api/lmes/v1alpha1/lmevaljob_types.go @@ -247,7 +247,7 @@ type LMEvalJobSpec struct { LogSamples *bool `json:"logSamples,omitempty"` // Batch size for the evaluation. This is used by the models that run and are loaded // locally and not apply for the commercial APIs. - BatchSize *int `json:"batchSize,omitempty"` + BatchSize *string `json:"batchSize,omitempty"` // Specify extra information for the lm-eval job's pod // +optional Pod *LMEvalPodSpec `json:"pod,omitempty"` diff --git a/api/lmes/v1alpha1/zz_generated.deepcopy.go b/api/lmes/v1alpha1/zz_generated.deepcopy.go index bc1014d2..ffc83f10 100644 --- a/api/lmes/v1alpha1/zz_generated.deepcopy.go +++ b/api/lmes/v1alpha1/zz_generated.deepcopy.go @@ -174,7 +174,7 @@ func (in *LMEvalJobSpec) DeepCopyInto(out *LMEvalJobSpec) { } if in.BatchSize != nil { in, out := &in.BatchSize, &out.BatchSize - *out = new(int) + *out = new(string) **out = **in } if in.Pod != nil { diff --git a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml index eb78a026..8d232b84 100644 --- a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml +++ b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml @@ -47,7 +47,7 @@ spec: description: |- Batch size for the evaluation. This is used by the models that run and are loaded locally and not apply for the commercial APIs. - type: integer + type: string genArgs: description: Map to `--gen_kwargs` parameter for the underlying library. items: diff --git a/controllers/lmes/config.go b/controllers/lmes/config.go index 53b47002..9d75038f 100644 --- a/controllers/lmes/config.go +++ b/controllers/lmes/config.go @@ -43,7 +43,7 @@ type serviceOptions struct { PodCheckingInterval time.Duration ImagePullPolicy corev1.PullPolicy MaxBatchSize int - DefaultBatchSize int + DefaultBatchSize string DetectDevice bool } diff --git a/controllers/lmes/constants.go b/controllers/lmes/constants.go index e77cbf7c..252523cf 100644 --- a/controllers/lmes/constants.go +++ b/controllers/lmes/constants.go @@ -38,7 +38,7 @@ const ( DefaultPodCheckingInterval = time.Second * 10 DefaultImagePullPolicy = corev1.PullAlways DefaultMaxBatchSize = 24 - DefaultBatchSize = 8 + DefaultBatchSize = "1" DefaultDetectDevice = true ServiceName = "LMES" ) diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index 344669a4..e610c750 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -22,6 +22,7 @@ import ( "fmt" "maps" "slices" + "strconv" "strings" "sync" "time" @@ -818,6 +819,42 @@ func mergeMapWithFilters(dest, src map[string]string, prefixFilters []string, lo } } +func validateBatchSize(input string, maxBatchSize int, log logr.Logger) string { + + maxBatchSizeString := strconv.Itoa(maxBatchSize) + + if input == "auto" { + // No validation needed, return original + return input + } + + // Validate "auto:N" style batch size + if strings.HasPrefix(input, "auto:") { + autoN := strings.TrimPrefix(input, "auto:") + if n, err := strconv.Atoi(autoN); err == nil && n > 0 { + // If N is a positive integer, use it and ignore maxBatchSize, since is now the maximum batch size + return input + } + // If N is an invalid integer, use "auto:maxBatchSize" + log.Info(input + " not supported. Using auto:" + maxBatchSizeString) + return "auto:" + maxBatchSizeString + } + + // Validate N batch size + if n, err := strconv.Atoi(input); err == nil && n > 0 { + // If N is valid, but larger than maxBatchSize, set it to maximum batch size + if n > maxBatchSize { + log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead") + return maxBatchSizeString + } + // If N is valid, use it + return strconv.Itoa(n) + } + + log.Info("invalid batchSize " + input + " using batch size " + DefaultBatchSize) + return DefaultBatchSize +} + func generateArgs(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string { if job == nil { return nil @@ -853,15 +890,12 @@ func generateArgs(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr } // --batch_size var batchSize = svcOpts.DefaultBatchSize - if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 { - batchSize = *job.Spec.BatchSize - } - // This could be done in the webhook if it's enabled. - if batchSize > svcOpts.MaxBatchSize { - batchSize = svcOpts.MaxBatchSize - log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead") + if job.Spec.BatchSize != nil { + // This could be done in the webhook if it's enabled. + batchSize = validateBatchSize(*job.Spec.BatchSize, svcOpts.MaxBatchSize, log) } - cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize)) + + cmds = append(cmds, "--batch_size", batchSize) return []string{"sh", "-ec", strings.Join(cmds, " ")} } diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index 9934e642..33ab08dd 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -660,7 +660,7 @@ func Test_GenerateArgBatchSize(t *testing.T) { DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, MaxBatchSize: 20, - DefaultBatchSize: 4, + DefaultBatchSize: "4", } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -686,11 +686,11 @@ func Test_GenerateArgBatchSize(t *testing.T) { // no batchSize in the job, use default batchSize assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.DefaultBatchSize), + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + svcOpts.DefaultBatchSize, }, generateArgs(svcOpts, job, log)) // exceed the max-batch-size, use max-batch-size - var biggerBatchSize = 30 + var biggerBatchSize = "30" job.Spec.BatchSize = &biggerBatchSize assert.Equal(t, []string{ "sh", "-ec", @@ -698,7 +698,7 @@ func Test_GenerateArgBatchSize(t *testing.T) { }, generateArgs(svcOpts, job, log)) // normal batchSize - var normalBatchSize = 16 + var normalBatchSize = "16" job.Spec.BatchSize = &normalBatchSize assert.Equal(t, []string{ "sh", "-ec", @@ -752,7 +752,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { // one TaskRecipe assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize, }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ @@ -777,7 +777,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { // one TaskRecipe assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size 8", + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize, }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ @@ -836,7 +836,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize, }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ @@ -1365,3 +1365,30 @@ func Test_PVCPreference(t *testing.T) { assert.Equal(t, expect, newPod) } + +func Test_ValidateBatchSize(t *testing.T) { + maxBatchSize := 32 + logger := log.Log.WithName("tests") + scenarios := []struct { + provided string + validated string + }{ + {"5", "5"}, + {"auto", "auto"}, + {"auto:3", "auto:3"}, + {"auto:0", "auto:" + strconv.Itoa(maxBatchSize)}, + {"auto:-5", "auto:" + strconv.Itoa(maxBatchSize)}, + {"64", strconv.Itoa(maxBatchSize)}, + {"-5", DefaultBatchSize}, + {"invalid", DefaultBatchSize}, + {"0", DefaultBatchSize}, + {"auto:auto", "auto:" + strconv.Itoa(maxBatchSize)}, + } + + for _, scenario := range scenarios { + result := validateBatchSize(scenario.provided, maxBatchSize, logger) + if result != scenario.validated { + t.Errorf("validateBatchSize(%q) = %q; want %q", scenario.provided, result, scenario.validated) + } + } +}