Skip to content

Commit

Permalink
RHOAIENG-14773: Change batchSize to string in order to support "auto"…
Browse files Browse the repository at this point in the history
… and "auto:N" (trustyai-explainability#350)

* Change batchSize from int to str to support "auto" and "auto:N"

* Remove redundant conversion of maxBatchSizeString
  • Loading branch information
ruivieira authored Oct 30, 2024
1 parent b76a51d commit 08effb0
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 20 deletions.
2 changes: 1 addition & 1 deletion api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ type LMEvalJobSpec struct {
LogSamples *bool `json:"logSamples,omitempty"`
// Batch size for the evaluation. This is used by models that are loaded and run
// locally and does not apply to the commercial APIs.
BatchSize *int `json:"batchSize,omitempty"`
BatchSize *string `json:"batchSize,omitempty"`
// Specify extra information for the lm-eval job's pod
// +optional
Pod *LMEvalPodSpec `json:"pod,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ spec:
description: |-
Batch size for the evaluation. This is used by models that are loaded and run
locally and does not apply to the commercial APIs.
type: integer
type: string
genArgs:
description: Map to `--gen_kwargs` parameter for the underlying library.
items:
Expand Down
2 changes: 1 addition & 1 deletion controllers/lmes/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type serviceOptions struct {
PodCheckingInterval time.Duration
ImagePullPolicy corev1.PullPolicy
MaxBatchSize int
DefaultBatchSize int
DefaultBatchSize string
DetectDevice bool
}

Expand Down
2 changes: 1 addition & 1 deletion controllers/lmes/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const (
DefaultPodCheckingInterval = time.Second * 10
DefaultImagePullPolicy = corev1.PullAlways
DefaultMaxBatchSize = 24
DefaultBatchSize = 8
DefaultBatchSize = "1"
DefaultDetectDevice = true
ServiceName = "LMES"
)
50 changes: 42 additions & 8 deletions controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"maps"
"slices"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -818,6 +819,42 @@ func mergeMapWithFilters(dest, src map[string]string, prefixFilters []string, lo
}
}

// validateBatchSize normalizes a user-supplied batch size string.
//
// Accepted forms and their results:
//   - "auto": returned unchanged.
//   - "auto:N": when N parses as a positive integer the result is "auto:N" with
//     N re-rendered in canonical decimal form (maxBatchSize is not applied as an
//     upper bound in this form); otherwise the result is "auto:<maxBatchSize>".
//   - "N": when N parses as a positive integer it is capped at maxBatchSize and
//     returned in canonical decimal form.
//   - anything else: DefaultBatchSize.
//
// Each fallback or cap is reported through log.
func validateBatchSize(input string, maxBatchSize int, log logr.Logger) string {
	maxBatchSizeString := strconv.Itoa(maxBatchSize)

	if input == "auto" {
		// No validation needed, return original
		return input
	}

	// Validate "auto:N" style batch size
	if autoN, ok := strings.CutPrefix(input, "auto:"); ok {
		if n, err := strconv.Atoi(autoN); err == nil && n > 0 {
			// Re-render N instead of echoing the raw input so that spellings
			// such as "auto:+3" or "auto:007" come out in the same canonical
			// form the plain-number branch below produces.
			return "auto:" + strconv.Itoa(n)
		}
		// If N is not a positive integer, use "auto:maxBatchSize"
		log.Info(input + " not supported. Using auto:" + maxBatchSizeString)
		return "auto:" + maxBatchSizeString
	}

	// Validate plain N batch size
	n, err := strconv.Atoi(input)
	if err != nil || n <= 0 {
		log.Info("invalid batchSize " + input + " using batch size " + DefaultBatchSize)
		return DefaultBatchSize
	}

	// N is valid, but larger than maxBatchSize: clamp it
	if n > maxBatchSize {
		log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead")
		return maxBatchSizeString
	}

	return strconv.Itoa(n)
}

func generateArgs(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string {
if job == nil {
return nil
Expand Down Expand Up @@ -853,15 +890,12 @@ func generateArgs(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr
}
// --batch_size
var batchSize = svcOpts.DefaultBatchSize
if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 {
batchSize = *job.Spec.BatchSize
}
// This could be done in the webhook if it's enabled.
if batchSize > svcOpts.MaxBatchSize {
batchSize = svcOpts.MaxBatchSize
log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead")
if job.Spec.BatchSize != nil {
// This could be done in the webhook if it's enabled.
batchSize = validateBatchSize(*job.Spec.BatchSize, svcOpts.MaxBatchSize, log)
}
cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize))

cmds = append(cmds, "--batch_size", batchSize)

return []string{"sh", "-ec", strings.Join(cmds, " ")}
}
Expand Down
41 changes: 34 additions & 7 deletions controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ func Test_GenerateArgBatchSize(t *testing.T) {
DriverImage: "driver:latest",
ImagePullPolicy: corev1.PullAlways,
MaxBatchSize: 20,
DefaultBatchSize: 4,
DefaultBatchSize: "4",
}
var job = &lmesv1alpha1.LMEvalJob{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -686,19 +686,19 @@ func Test_GenerateArgBatchSize(t *testing.T) {
// no batchSize in the job, use default batchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.DefaultBatchSize),
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + svcOpts.DefaultBatchSize,
}, generateArgs(svcOpts, job, log))

// exceed the max-batch-size, use max-batch-size
var biggerBatchSize = 30
var biggerBatchSize = "30"
job.Spec.BatchSize = &biggerBatchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.MaxBatchSize),
}, generateArgs(svcOpts, job, log))

// normal batchSize
var normalBatchSize = 16
var normalBatchSize = "16"
job.Spec.BatchSize = &normalBatchSize
assert.Equal(t, []string{
"sh", "-ec",
Expand Down Expand Up @@ -752,7 +752,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
// one TaskRecipe
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize,
}, generateArgs(svcOpts, job, log))

assert.Equal(t, []string{
Expand All @@ -777,7 +777,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
// two TaskRecipes
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size 8",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize,
}, generateArgs(svcOpts, job, log))

assert.Equal(t, []string{
Expand Down Expand Up @@ -836,7 +836,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) {

assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size " + DefaultBatchSize,
}, generateArgs(svcOpts, job, log))

assert.Equal(t, []string{
Expand Down Expand Up @@ -1365,3 +1365,30 @@ func Test_PVCPreference(t *testing.T) {

assert.Equal(t, expect, newPod)
}

// Test_ValidateBatchSize runs validateBatchSize over a table of inputs with a
// maximum batch size of 32 and compares each result with the expected string.
func Test_ValidateBatchSize(t *testing.T) {
	const limit = 32
	limitStr := strconv.Itoa(limit)
	autoLimit := "auto:" + limitStr
	testLog := log.Log.WithName("tests")

	cases := []struct {
		in   string
		want string
	}{
		{in: "5", want: "5"},
		{in: "auto", want: "auto"},
		{in: "auto:3", want: "auto:3"},
		{in: "auto:0", want: autoLimit},
		{in: "auto:-5", want: autoLimit},
		{in: "64", want: limitStr},
		{in: "-5", want: DefaultBatchSize},
		{in: "invalid", want: DefaultBatchSize},
		{in: "0", want: DefaultBatchSize},
		{in: "auto:auto", want: autoLimit},
	}

	for i := range cases {
		tc := cases[i]
		if got := validateBatchSize(tc.in, limit, testLog); got != tc.want {
			t.Errorf("validateBatchSize(%q) = %q; want %q", tc.in, got, tc.want)
		}
	}
}

0 comments on commit 08effb0

Please sign in to comment.