From 6551b6a2244fe0f66b6ada26dfd68ace9fa4cc94 Mon Sep 17 00:00:00 2001 From: Yihong Wang Date: Wed, 19 Feb 2025 10:11:36 -0800 Subject: [PATCH] feat(LMES): Support custom template and prompt (#404) Expand the LMEvalJob CRD to support custom templates and system prompts. This is mainly for custom unitxt task recipes. Now, users can use the `template` and `systemPrompt` fields under the `taskRecipes` to specify the custom template and system prompt. Signed-off-by: Yihong Wang --- Dockerfile.lmes-job | 2 +- api/lmes/v1alpha1/lmevaljob_types.go | 79 ++++++- api/lmes/v1alpha1/zz_generated.deepcopy.go | 85 +++++++ cmd/lmes_driver/main.go | 44 ++-- .../trustyai.opendatahub.io_lmevaljobs.yaml | 63 ++++- controllers/lmes/driver/driver.go | 85 ++++++- controllers/lmes/driver/driver_test.go | 192 ++++++++++------ controllers/lmes/lmevaljob_controller.go | 100 ++++++-- controllers/lmes/lmevaljob_controller_test.go | 217 +++++++++++++++++- 9 files changed, 748 insertions(+), 119 deletions(-) diff --git a/Dockerfile.lmes-job b/Dockerfile.lmes-job index b3e0af93..55791180 100644 --- a/Dockerfile.lmes-job +++ b/Dockerfile.lmes-job @@ -20,7 +20,7 @@ RUN curl -L https://github.com/opendatahub-io/lm-evaluation-harness/archive/refs RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt -ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server +ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src:/opt/app-root/src/server ENV HF_HOME=/opt/app-root/src/hf_home ENV UNITXT_CATALOGS=/opt/app-root/src/my_catalogs diff --git a/api/lmes/v1alpha1/lmevaljob_types.go b/api/lmes/v1alpha1/lmevaljob_types.go index 18887431..a1909a74 100644 --- a/api/lmes/v1alpha1/lmevaljob_types.go +++ b/api/lmes/v1alpha1/lmevaljob_types.go @@ -76,6 +76,57 @@ type Card struct { Custom string `json:"custom,omitempty"` } +type Template struct { + // Unitxt template ID + // +optional + Name string `json:"name,omitempty"` + // The name of the custom template in the custom field. Its value is a JSON string + // for a custom Unitxt template. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html + // to compose a custom template, store it as a JSON file by calling the + // add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog, + // and use the JSON content as the value here. + // +optional + Ref string `json:"ref,omitempty"` +} + +type SystemPrompt struct { + // Unitxt System Prompt id + Name string `json:"name,omitempty"` + // The name of the custom systemPrompt in the custom field. Its value is a custom system prompt string + Ref string `json:"ref,omitempty"` +} + +type CustomArtifact struct { + // Name of the custom artifact + Name string `json:"name"` + // Value of the custom artifact. It could be a JSON string or plain text + // depending on the artifact type + Value string `json:"value"` +} + +func (c *CustomArtifact) String() string { + return fmt.Sprintf("%s|%s", c.Name, c.Value) +} + +type CustomArtifacts struct { + Templates []CustomArtifact `json:"templates,omitempty"` + SystemPrompts []CustomArtifact `json:"systemPrompts,omitempty"` +} + +func (c *CustomArtifacts) GetTemplates() []CustomArtifact { + if c == nil { + return nil + } + return c.Templates +} + +func (c *CustomArtifacts) GetSystemPrompts() []CustomArtifact { + if c == nil { + return nil + } + return c.SystemPrompts +} + // Use a task recipe to form a custom task. It maps to the Unitxt Recipe // Find details of the Unitxt Recipe here: // https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe @@ -83,7 +134,11 @@ type TaskRecipe struct { // The Unitxt dataset card Card Card `json:"card"` // The Unitxt template - Template string `json:"template"` + // +optional + Template *Template `json:"template,omitempty"` + // The Unitxt System Prompt + // +optional + SystemPrompt *SystemPrompt `json:"systemPrompt,omitempty"` // The Unitxt Task // +optional Task *string `json:"task,omitempty"` @@ -109,11 +164,31 @@ type TaskList struct { TaskNames []string `json:"taskNames,omitempty"` // Task Recipes specifically for Unitxt TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"` + // Custom Unitxt artifacts that can be used in a TaskRecipe + CustomArtifacts *CustomArtifacts `json:"custom,omitempty"` } +// Use the tp_idx and sp_idx to point to the corresponding custom template +// and custom system_prompt func (t *TaskRecipe) String() string { var b strings.Builder - b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template)) + b.WriteString(fmt.Sprintf("card=%s", t.Card.Name)) + if t.Template != nil { + if t.Template.Name != "" { + b.WriteString(fmt.Sprintf(",template=%s", t.Template.Name)) + } else { + // refer to a custom template. add "templates." prefix + b.WriteString(fmt.Sprintf(",template=templates.%s", t.Template.Ref)) + } + } + if t.SystemPrompt != nil { + if t.SystemPrompt.Name != "" { + b.WriteString(fmt.Sprintf(",system_prompt=%s", t.SystemPrompt.Name)) + } else { + // refer to custom system prompt. add "system_prompts." prefix + b.WriteString(fmt.Sprintf(",system_prompt=system_prompts.%s", t.SystemPrompt.Ref)) + } + } if t.Task != nil { b.WriteString(fmt.Sprintf(",task=%s", *t.Task)) } diff --git a/api/lmes/v1alpha1/zz_generated.deepcopy.go b/api/lmes/v1alpha1/zz_generated.deepcopy.go index 7f19b79b..60b9a27c 100644 --- a/api/lmes/v1alpha1/zz_generated.deepcopy.go +++ b/api/lmes/v1alpha1/zz_generated.deepcopy.go @@ -55,6 +55,46 @@ func (in *Card) DeepCopy() *Card { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *CustomArtifact) DeepCopyInto(out *CustomArtifact) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CustomArtifact. +func (in *CustomArtifact) DeepCopy() *CustomArtifact { + if in == nil { + return nil + } + out := new(CustomArtifact) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *CustomArtifacts) DeepCopyInto(out *CustomArtifacts) { + *out = *in + if in.Templates != nil { + in, out := &in.Templates, &out.Templates + *out = make([]CustomArtifact, len(*in)) + copy(*out, *in) + } + if in.SystemPrompts != nil { + in, out := &in.SystemPrompts, &out.SystemPrompts + *out = make([]CustomArtifact, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CustomArtifacts. +func (in *CustomArtifacts) DeepCopy() *CustomArtifacts { + if in == nil { + return nil + } + out := new(CustomArtifacts) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *LMEvalContainer) DeepCopyInto(out *LMEvalContainer) { *out = *in @@ -397,6 +437,21 @@ func (in *PersistentVolumeClaimManaged) DeepCopy() *PersistentVolumeClaimManaged return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SystemPrompt) DeepCopyInto(out *SystemPrompt) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SystemPrompt. +func (in *SystemPrompt) DeepCopy() *SystemPrompt { + if in == nil { + return nil + } + out := new(SystemPrompt) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *TaskList) DeepCopyInto(out *TaskList) { *out = *in @@ -412,6 +467,11 @@ func (in *TaskList) DeepCopyInto(out *TaskList) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.CustomArtifacts != nil { + in, out := &in.CustomArtifacts, &out.CustomArtifacts + *out = new(CustomArtifacts) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskList. @@ -428,6 +488,16 @@ func (in *TaskList) DeepCopy() *TaskList { func (in *TaskRecipe) DeepCopyInto(out *TaskRecipe) { *out = *in out.Card = in.Card + if in.Template != nil { + in, out := &in.Template, &out.Template + *out = new(Template) + **out = **in + } + if in.SystemPrompt != nil { + in, out := &in.SystemPrompt, &out.SystemPrompt + *out = new(SystemPrompt) + **out = **in + } if in.Task != nil { in, out := &in.Task, &out.Task *out = new(string) @@ -469,3 +539,18 @@ func (in *TaskRecipe) DeepCopy() *TaskRecipe { in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Template) DeepCopyInto(out *Template) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Template. +func (in *Template) DeepCopy() *Template { + if in == nil { + return nil + } + out := new(Template) + in.DeepCopyInto(out) + return out +} diff --git a/cmd/lmes_driver/main.go b/cmd/lmes_driver/main.go index fa9e92f8..c099364f 100644 --- a/cmd/lmes_driver/main.go +++ b/cmd/lmes_driver/main.go @@ -50,21 +50,25 @@ func (t *strArrayArg) String() string { } var ( - taskRecipes strArrayArg - customCards strArrayArg - copy = flag.String("copy", "", "copy this binary to specified destination path") - getStatus = flag.Bool("get-status", false, "Get current status") - shutdown = flag.Bool("shutdown", false, "Shutdown the driver") - outputPath = flag.String("output-path", OutputPath, "output path") - detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU") - commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port") - downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3") - driverLog = ctrl.Log.WithName("driver") + taskRecipes strArrayArg + customCards strArrayArg + customTemplates strArrayArg + customSystemPrompts strArrayArg + copy = flag.String("copy", "", "copy this binary to specified destination path") + getStatus = flag.Bool("get-status", false, "Get current status") + shutdown = flag.Bool("shutdown", false, "Shutdown the driver") + outputPath = flag.String("output-path", OutputPath, "output path") + detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU") + commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port") + downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3") + driverLog = ctrl.Log.WithName("driver") ) func init() { flag.Var(&taskRecipes, "task-recipe", "task recipe") flag.Var(&customCards, "custom-card", "A JSON string represents a custom card") + flag.Var(&customTemplates, "custom-template", "A JSON string represents a custom template") + flag.Var(&customSystemPrompts, "custom-prompt", "A string represents a custom system_prompt") } func main() { @@ -107,15 +111,17 @@ func main() { } driverOpt := driver.DriverOption{ - Context: ctx, - OutputPath: *outputPath, - DetectDevice: *detectDevice, - Logger: driverLog, - TaskRecipes: taskRecipes, - CustomCards: customCards, - Args: args, - CommPort: *commPort, - DownloadAssetsS3: *downloadAssetsS3, + Context: ctx, + OutputPath: *outputPath, + DetectDevice: *detectDevice, + Logger: driverLog, + TaskRecipes: taskRecipes, + CustomCards: customCards, + CustomTemplates: customTemplates, + CustomSystemPrompt: customSystemPrompts, + Args: args, + CommPort: *commPort, + DownloadAssetsS3: *downloadAssetsS3, } driver, err := driver.NewDriver(&driverOpt) diff --git a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml index 5d1c723c..4d214dbf 100644 --- a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml +++ b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml @@ -4721,6 +4721,42 @@ spec: taskList: description: Evaluation task list properties: + custom: + description: Custom Unitxt artifacts that can be used in a TaskRecipe + properties: + systemPrompts: + items: + properties: + name: + description: Name of the custom artifact + type: string + value: + description: |- + Value of the custom artifact. It could be a JSON string or plain text + depending on the artifact type + type: string + required: + - name + - value + type: object + type: array + templates: + items: + properties: + name: + description: Name of the custom artifact + type: string + value: + description: |- + Value of the custom artifact. It could be a JSON string or plain text + depending on the artifact type + type: string + required: + - name + - value + type: object + type: array + type: object taskNames: description: TaskNames from lm-eval's task list items: @@ -4764,15 +4800,38 @@ spec: numDemos: description: Number of fewshot type: integer + systemPrompt: + description: The Unitxt System Prompt + properties: + name: + description: Unitxt System Prompt id + type: string + ref: + description: The name of the custom systemPrompt in + the custom field. Its value is a custom system prompt + string + type: string + type: object task: description: The Unitxt Task type: string template: description: The Unitxt template - type: string + properties: + name: + description: Unitxt template ID + type: string + ref: + description: |- + The name of the custom template in the custom field. Its value is a JSON string + for a custom Unitxt template. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html + to compose a custom template, store it as a JSON file by calling the + add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog, + and use the JSON content as the value here. + type: string + type: object required: - card - - template type: object type: array type: object diff --git a/controllers/lmes/driver/driver.go b/controllers/lmes/driver/driver.go index b98c949f..c44aaa05 100644 --- a/controllers/lmes/driver/driver.go +++ b/controllers/lmes/driver/driver.go @@ -20,6 +20,7 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "io" "io/fs" @@ -53,17 +54,19 @@ const ( ) type DriverOption struct { - Context context.Context - OutputPath string - DetectDevice bool - TaskRecipesPath string - TaskRecipes []string - CatalogPath string - CustomCards []string - Logger logr.Logger - Args []string - CommPort int - DownloadAssetsS3 bool + Context context.Context + OutputPath string + DetectDevice bool + TaskRecipesPath string + TaskRecipes []string + CatalogPath string + CustomCards []string + CustomTemplates []string + CustomSystemPrompt []string + Logger logr.Logger + Args []string + CommPort int + DownloadAssetsS3 bool } type Driver interface { @@ -321,6 +324,10 @@ func (d *driverImpl) exec() error { return fmt.Errorf("failed to create task recipes: %v", err) } + if err := d.prepDir4CustomArtifacts(); err != nil { + return fmt.Errorf("failed to create the directories for custom artifacts: %v", err) + } + if err := d.createCustomCards(); err != nil { return fmt.Errorf("failed to create custom cards: %v", err) } @@ -329,6 +336,12 @@ func (d *driverImpl) exec() error { if err := d.downloadS3Assets(); err != nil { return err } + if err := d.createCustomTemplates(); err != nil { + return fmt.Errorf("failed to create custom templates: %v", err) + } + if err := d.createCustomSystemPrompts(); err != nil { + return fmt.Errorf("failed to create custom system_prompts: %v", err) + } // Detect available devices if needed if err := d.detectDevice(); err != nil { @@ -491,6 +504,15 @@ func (d *driverImpl) createTaskRecipes() error { return nil } +func (d *driverImpl) prepDir4CustomArtifacts() error { + subDirs := []string{"cards", "templates", "system_prompts"} + var errs []error + for _, dir := range subDirs { + errs = append(errs, mkdirIfNotExist(filepath.Join(d.Option.CatalogPath, dir))) + } + return errors.Join(errs...) +} + func (d *driverImpl) createCustomCards() error { for i, customCard := range d.Option.CustomCards { err := os.WriteFile( @@ -504,3 +526,44 @@ func (d *driverImpl) createCustomCards() error { } return nil } + +func (d *driverImpl) createCustomTemplates() error { + for _, customTemplate := range d.Option.CustomTemplates { + values := strings.SplitN(customTemplate, "|", 2) + err := os.WriteFile( + filepath.Join(d.Option.CatalogPath, "templates", fmt.Sprintf("%s.json", values[0])), + []byte(values[1]), + 0666, + ) + if err != nil { + return err + } + } + return nil +} + +func (d *driverImpl) createCustomSystemPrompts() error { + for _, systemPrompt := range d.Option.CustomSystemPrompt { + values := strings.SplitN(systemPrompt, "|", 2) + err := os.WriteFile( + filepath.Join(d.Option.CatalogPath, "system_prompts", fmt.Sprintf("%s.json", values[0])), + []byte(fmt.Sprintf(`{ "__type__": "textual_system_prompt", "text": "%s" }`, values[1])), + 0666, + ) + if err != nil { + return err + } + } + return nil +} + +func mkdirIfNotExist(path string) error { + fi, err := os.Stat(path) + if err == nil && !fi.IsDir() { + return fmt.Errorf("%s is a file. can not create a directory", path) + } + if os.IsNotExist(err) { + return os.MkdirAll(path, 0770) + } + return nil +} diff --git a/controllers/lmes/driver/driver_test.go b/controllers/lmes/driver/driver_test.go index 63e70cc8..15c32d21 100644 --- a/controllers/lmes/driver/driver_test.go +++ b/controllers/lmes/driver/driver_test.go @@ -19,8 +19,10 @@ package driver import ( "context" "flag" + "fmt" "math/rand" "os" + "path/filepath" "testing" "time" @@ -46,10 +48,39 @@ func TestMain(m *testing.M) { m.Run() } -// to support concurrent testing if needed, each test needs a -// dedicated port -func genRandomPort() int { - return rand.Intn(1000) + 18080 +type testInfo struct { + outputPath string + catalogPath string + taskPath string + port int + tearDown func(*testing.T) +} + +func setupTest(t *testing.T, hasOutput bool) testInfo { + folderSuffix := rand.Intn(100) + outputPath := fmt.Sprintf("outputs%d", folderSuffix) + catalogPath := fmt.Sprintf("mycatalogs%d", folderSuffix) + taskPath := fmt.Sprintf("mytasks%d", folderSuffix) + os.Mkdir(outputPath, 0750) + os.Mkdir(catalogPath, 0750) + os.Mkdir(taskPath, 0750) + + testInfo := testInfo{ + outputPath: outputPath, + catalogPath: catalogPath, + taskPath: taskPath, + port: rand.Intn(1000) + 18080, + tearDown: func(t *testing.T) { + if hasOutput { + assert.Nil(t, os.Remove(filepath.Join(outputPath, "stderr.log"))) + assert.Nil(t, os.Remove(filepath.Join(outputPath, "stdout.log"))) + } + assert.Nil(t, os.RemoveAll(taskPath)) + assert.Nil(t, os.RemoveAll(outputPath)) + assert.Nil(t, os.RemoveAll(catalogPath)) + }, + } + return testInfo } func runDriverAndWait4Complete(t *testing.T, driver Driver, returnError bool) (progressMsgs []string, results string) { @@ -77,29 +108,33 @@ func runDriverAndWait4Complete(t *testing.T, driver Driver, returnError bool) (p } func Test_Driver(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "echo tttttttttttttttttttt"}, - CommPort: genRandomPort(), + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "echo tttttttttttttttttttt"}, + CommPort: info.port, }) assert.Nil(t, err) - runDriverAndWait4Complete(t, driver, false) - assert.Nil(t, driver.Shutdown()) - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) } func Test_Wait4Shutdown(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "echo test"}, - CommPort: genRandomPort(), + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "echo test"}, + CommPort: info.port, }) assert.Nil(t, err) @@ -115,18 +150,19 @@ func Test_Wait4Shutdown(t *testing.T) { _, err = driver.GetStatus() assert.ErrorContains(t, err, "connection refused") - - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) } func Test_ProgressUpdate(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"}, - CommPort: genRandomPort(), + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"}, + CommPort: info.port, }) assert.Nil(t, err) @@ -139,18 +175,20 @@ func Test_ProgressUpdate(t *testing.T) { }, msgs) assert.Nil(t, driver.Shutdown()) - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) } func Test_DetectDeviceError(t *testing.T) { + info := setupTest(t, false) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ Context: context.Background(), - OutputPath: ".", + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, DetectDevice: true, Logger: driverLog, Args: []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"}, - CommPort: genRandomPort(), + CommPort: info.port, }) assert.Nil(t, err) @@ -162,18 +200,22 @@ func Test_DetectDeviceError(t *testing.T) { assert.Nil(t, driver.Shutdown()) // the following files don't exist for this case - assert.NotNil(t, os.Remove("./stderr.log")) - assert.NotNil(t, os.Remove("./stdout.log")) + assert.NotNil(t, os.Remove(filepath.Join(info.outputPath, "stderr.log"))) + assert.NotNil(t, os.Remove(filepath.Join(info.outputPath, "stdout.log"))) } func Test_DownloadAssetsS3Error(t *testing.T) { + info := setupTest(t, false) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ Context: context.Background(), - OutputPath: ".", + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, DetectDevice: false, Logger: driverLog, Args: []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"}, - CommPort: genRandomPort(), + CommPort: info.port, DownloadAssetsS3: true, }) assert.Nil(t, err) @@ -220,17 +262,21 @@ func Test_PatchDevice(t *testing.T) { } func Test_TaskRecipes(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ Context: context.Background(), - OutputPath: ".", + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, Logger: driverLog, - TaskRecipesPath: "./", + TaskRecipesPath: info.taskPath, TaskRecipes: []string{ "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", }, Args: []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"}, - CommPort: genRandomPort(), + CommPort: info.port, }) assert.Nil(t, err) @@ -244,44 +290,48 @@ func Test_TaskRecipes(t *testing.T) { assert.Nil(t, driver.Shutdown()) - tr0, err := os.ReadFile("./tr_0.yaml") + tr0, err := os.ReadFile(filepath.Join(info.taskPath, "tr_0.yaml")) assert.Nil(t, err) assert.Equal(t, "task: tr_0\ninclude: unitxt\nrecipe: card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", string(tr0), ) - tr1, err := os.ReadFile("./tr_1.yaml") + tr1, err := os.ReadFile(filepath.Join(info.taskPath, "tr_1.yaml")) assert.Nil(t, err) assert.Equal(t, "task: tr_1\ninclude: unitxt\nrecipe: card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", string(tr1), ) - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) - assert.Nil(t, os.Remove("./tr_0.yaml")) - assert.Nil(t, os.Remove("./tr_1.yaml")) } func Test_CustomCards(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ Context: context.Background(), - OutputPath: ".", + OutputPath: info.outputPath, Logger: driverLog, - TaskRecipesPath: "./", - CatalogPath: "./", + TaskRecipesPath: info.taskPath, + CatalogPath: info.catalogPath, TaskRecipes: []string{ "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", + "card=cards.unitxt.card1,template=templates.tp_0,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", }, CustomCards: []string{ `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, }, + CustomTemplates: []string{ + `tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + }, + CustomSystemPrompt: []string{ + "sp_0|this is a custom system prompt", + }, Args: []string{"sh", "-ec", "sleep 1; echo 'testing progress: 100%|' >&2; sleep 3"}, - CommPort: genRandomPort(), + CommPort: info.port, }) assert.Nil(t, err) - os.Mkdir("cards", 0750) - msgs, _ := runDriverAndWait4Complete(t, driver, false) assert.Equal(t, []string{ @@ -292,32 +342,49 @@ func Test_CustomCards(t *testing.T) { assert.Nil(t, driver.Shutdown()) - tr0, err := os.ReadFile("./tr_0.yaml") + tr0, err := os.ReadFile(filepath.Join(info.taskPath, "tr_0.yaml")) assert.Nil(t, err) assert.Equal(t, "task: tr_0\ninclude: unitxt\nrecipe: card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", string(tr0), ) - custom0, err := os.ReadFile("./cards/custom_0.json") + tr1, err := os.ReadFile(filepath.Join(info.taskPath, "tr_1.yaml")) + assert.Nil(t, err) + assert.Equal(t, + "task: tr_1\ninclude: unitxt\nrecipe: card=cards.unitxt.card1,template=templates.tp_0,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", + string(tr1), + ) + custom0, err := os.ReadFile(filepath.Join(info.catalogPath, "cards", "custom_0.json")) assert.Nil(t, err) assert.Equal(t, `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, string(custom0), ) - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) - assert.Nil(t, os.Remove("./tr_0.yaml")) - assert.Nil(t, os.Remove("./cards/custom_0.json")) - assert.Nil(t, os.Remove("./cards")) + template0, err := os.ReadFile(filepath.Join(info.catalogPath, "templates", "tp_0.json")) + assert.Nil(t, err) + assert.Equal(t, + `{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + string(template0), + ) + prompt0, err := os.ReadFile(filepath.Join(info.catalogPath, "system_prompts", "sp_0.json")) + assert.Nil(t, err) + assert.Equal(t, + `{ "__type__": "textual_system_prompt", "text": "this is a custom system prompt" }`, + string(prompt0), + ) } func Test_ProgramError(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "sleep 1; exit 1"}, - CommPort: genRandomPort(), + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "sleep 1; exit 1"}, + CommPort: info.port, }) assert.Nil(t, err) @@ -329,7 +396,4 @@ func Test_ProgramError(t *testing.T) { }, msgs) assert.Nil(t, driver.Shutdown()) - - assert.Nil(t, os.Remove("./stderr.log")) - assert.Nil(t, os.Remove("./stdout.log")) } diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index e7f21ab0..569392b1 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -221,7 +221,8 @@ func (r *LMEvalJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( } } } - log.Info("Continuing after PVC") + + log.Info("Checking the job state") // Handle the job based on its state switch job.Status.State { @@ -415,15 +416,18 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, // Validate the custom card if exists // FIXME: Move the validation to the webhook once we enable it. - if err := r.validateCustomCard(job, log); err != nil { + if err := r.validateCustomRecipes(job, log); err != nil { // custom card validation failed job.Status.State = lmesv1alpha1.CompleteJobState job.Status.Reason = lmesv1alpha1.FailedReason job.Status.Message = err.Error() + // also update the complete time + current := v1.Now() + job.Status.CompleteTime = ¤t if err := r.Status().Update(ctx, job); err != nil { - log.Error(err, "unable to update LMEvalJob status for custom card validation error") + log.Error(err, "unable to update LMEvalJob status for custom recipe validation error") } - log.Error(err, "Contain invalid custom card in the LMEvalJob", "name", job.Name) + log.Error(err, "Contain invalid custom recipe in the LMEvalJob", "name", job.Name) return ctrl.Result{}, err } @@ -523,6 +527,10 @@ func (r *LMEvalJobReconciler) getPod(ctx context.Context, job *lmesv1alpha1.LMEv } func (r *LMEvalJobReconciler) deleteJobPod(ctx context.Context, job *lmesv1alpha1.LMEvalJob) error { + if job.Status.PodName == "" { + return nil + } + pod := corev1.Pod{ TypeMeta: v1.TypeMeta{ Kind: "Pod", @@ -644,15 +652,17 @@ func (r *LMEvalJobReconciler) handleResume(ctx context.Context, log logr.Logger, return ctrl.Result{}, err } -func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, log logr.Logger) error { +func (r *LMEvalJobReconciler) validateCustomRecipes(job *lmesv1alpha1.LMEvalJob, log logr.Logger) error { if job.Spec.TaskList.TaskRecipes == nil { return nil } + var checkedTemplates []string + customArtifacts := job.Spec.TaskList.CustomArtifacts for _, taskRecipe := range job.Spec.TaskList.TaskRecipes { if taskRecipe.Card.Custom != "" { - var card map[string]interface{} - if err := json.Unmarshal([]byte(taskRecipe.Card.Custom), &card); err != nil { + card, err := unmarshal(taskRecipe.Card.Custom) + if err != nil { log.Error(err, "failed to parse the custom card") return fmt.Errorf("custom card is not a valid JSON string, %s", err.Error()) } @@ -663,11 +673,61 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo return missKeyError } } + + // validate the template JSON + if taskRecipe.Template != nil && taskRecipe.Template.Ref != "" && + !slices.Contains(checkedTemplates, taskRecipe.Template.Ref) { + + custom := getCustomArtifactByName(customArtifacts.GetTemplates(), taskRecipe.Template.Ref) + if custom == nil { + return fmt.Errorf("the reference name of the custom template is not defined: %s", taskRecipe.Template.Ref) + } + + template, err := unmarshal(custom.Value) + if err != nil { + log.Error(err, "failed to parse the custom template") + return fmt.Errorf("custom template is not a valid JSON string, %s", err.Error()) + } + // two mandatory fields: input_format and output_format + for _, fieldName := range []string{"input_format", "output_format"} { + if _, ok := template[fieldName]; !ok { + missKeyError := fmt.Errorf("no %s definition in the custom template", fieldName) + log.Error(missKeyError, "failed to parse the custom template") + return missKeyError + } + } + checkedTemplates = append(checkedTemplates, taskRecipe.Template.Ref) + } + + // only check if the system prompt is defined in the custom section or not + if taskRecipe.SystemPrompt != nil && taskRecipe.SystemPrompt.Ref != "" { + if getCustomArtifactByName(customArtifacts.GetSystemPrompts(), taskRecipe.SystemPrompt.Ref) == nil { + return fmt.Errorf("the reference name of the custom system prompt is not defined: %s", taskRecipe.SystemPrompt.Ref) + } + } } return nil } +func getCustomArtifactByName(customs []lmesv1alpha1.CustomArtifact, name string) *lmesv1alpha1.CustomArtifact { + if len(customs) == 0 { + return nil + } + for _, custom := range customs { + if custom.Name == name { + return &custom + } + } + return nil +} + +func unmarshal(custom string) (map[string]interface{}, error) { + var obj map[string]interface{} + err := json.Unmarshal([]byte(custom), &obj) + return obj, err +} + func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { var envVars = removeProtectedEnvVars(job.Spec.Pod.GetContainer().GetEnv()) @@ -1187,20 +1247,34 @@ func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string cmds = append(cmds, "--listen-port", fmt.Sprintf("%d", svcOpts.DriverPort)) } + if svcOpts.DriverPort != 0 && svcOpts.DriverPort != driver.DefaultPort { + cmds = append(cmds, "--listen-port", fmt.Sprintf("%d", svcOpts.DriverPort)) + } + cr_idx := 0 for _, recipe := range job.Spec.TaskList.TaskRecipes { - if recipe.Card.Name != "" { - // built-in card, regular recipe - cmds = append(cmds, "--task-recipe", recipe.String()) - } else if recipe.Card.Custom != "" { + // duplicate the TaskRecipe and update its content to generate proper recipe string + dupRecipe := recipe.DeepCopy() + + if recipe.Card.Custom != "" { // custom card, need to inject --custom-card arg as well - dupRecipe := recipe.DeepCopy() // the format of a custom card's name: custom_ dupRecipe.Card.Name = fmt.Sprintf("cards.%s_%d", driver.CustomCardPrefix, cr_idx) - cmds = append(cmds, "--task-recipe", dupRecipe.String()) cmds = append(cmds, "--custom-card", dupRecipe.Card.Custom) cr_idx++ } + + cmds = append(cmds, "--task-recipe", dupRecipe.String()) + } + + // go through custom artificats and add corresponding arguments + if job.Spec.TaskList.CustomArtifacts != nil { + for _, template := range job.Spec.TaskList.CustomArtifacts.GetTemplates() { + cmds = append(cmds, "--custom-template", template.String()) + } + for _, prompt := range job.Spec.TaskList.CustomArtifacts.GetSystemPrompts() { + cmds = append(cmds, "--custom-prompt", prompt.String()) + } } cmds = append(cmds, "--") diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index 1e9e137e..6b0e72f1 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -906,7 +906,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { TaskRecipes: []lmesv1alpha1.TaskRecipe{ { Card: lmesv1alpha1.Card{Name: "unitxt.card1"}, - Template: "unitxt.template", + Template: &lmesv1alpha1.Template{Name: "unitxt.template"}, Format: &format, Metrics: []string{"unitxt.metric1", "unitxt.metric2"}, NumDemos: &numDemos, @@ -933,7 +933,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, lmesv1alpha1.TaskRecipe{ Card: lmesv1alpha1.Card{Name: "unitxt.card2"}, - Template: "unitxt.template2", + Template: &lmesv1alpha1.Template{Name: "unitxt.template2"}, Format: &format, Metrics: []string{"unitxt.metric3", "unitxt.metric4"}, NumDemos: &numDemos, @@ -991,7 +991,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { Card: lmesv1alpha1.Card{ Custom: `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, }, - Template: "unitxt.template", + Template: &lmesv1alpha1.Template{Name: "unitxt.template"}, Format: &format, Metrics: []string{"unitxt.metric1", "unitxt.metric2"}, NumDemos: &numDemos, @@ -1010,8 +1010,159 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", "--output-path", "/opt/app-root/src/output", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--", + }, generateCmd(svcOpts, job)) + + // add second task using custom recipe + custom template + job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, + lmesv1alpha1.TaskRecipe{ + Card: lmesv1alpha1.Card{ + Custom: `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + }, + Template: &lmesv1alpha1.Template{ + Ref: "tp_0", + }, + Format: &format, + Metrics: []string{"unitxt.metric3", "unitxt.metric4"}, + NumDemos: &numDemos, + DemosPoolSize: &demosPoolSize, + }, + ) + + job.Spec.TaskList.CustomArtifacts = &lmesv1alpha1.CustomArtifacts{ + Templates: []lmesv1alpha1.CustomArtifact{ + { + Name: "tp_0", + Value: `{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + }, + }, + } + + assert.Equal(t, []string{ + "/opt/app-root/src/bin/driver", + "--output-path", "/opt/app-root/src/output", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-template", `tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + "--", + }, generateCmd(svcOpts, job)) + + // add third task using normal card + custom system_prompt + job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, + lmesv1alpha1.TaskRecipe{ + Card: lmesv1alpha1.Card{Name: "unitxt.card"}, + SystemPrompt: &lmesv1alpha1.SystemPrompt{ + Ref: "sp_0", + }, + Format: &format, + Metrics: []string{"unitxt.metric4", "unitxt.metric5"}, + NumDemos: &numDemos, + DemosPoolSize: &demosPoolSize, + }, + ) + + job.Spec.TaskList.CustomArtifacts.SystemPrompts = []lmesv1alpha1.CustomArtifact{ + { + Name: "sp_0", + Value: "this is a custom system promp", + }, + } + + assert.Equal(t, []string{ + "/opt/app-root/src/bin/driver", + "--output-path", "/opt/app-root/src/output", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--task-recipe", "card=unitxt.card,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric4,unitxt.metric5],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-template", `tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + "--custom-prompt", "sp_0|this is a custom system promp", + "--", + }, generateCmd(svcOpts, job)) + + // add forth task using custom card + custom template + custom system_prompt + // and reuse the template and system prompt + job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, + lmesv1alpha1.TaskRecipe{ + Card: lmesv1alpha1.Card{ + Custom: `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + }, + Template: &lmesv1alpha1.Template{ + Ref: "tp_0", + }, + SystemPrompt: &lmesv1alpha1.SystemPrompt{ + Ref: "sp_0", + }, + Format: &format, + Metrics: []string{"unitxt.metric6", "unitxt.metric7"}, + NumDemos: &numDemos, + DemosPoolSize: &demosPoolSize, + }, + ) + + assert.Equal(t, []string{ + "/opt/app-root/src/bin/driver", + "--output-path", "/opt/app-root/src/output", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--task-recipe", "card=unitxt.card,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric4,unitxt.metric5],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_2,template=templates.tp_0,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric6,unitxt.metric7],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-template", `tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + "--custom-prompt", "sp_0|this is a custom system promp", + "--", + }, generateCmd(svcOpts, job)) + + // add fifth task using regular card + custom template + custom system_prompt + // both template and system prompt are new + job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, + lmesv1alpha1.TaskRecipe{ + Card: lmesv1alpha1.Card{Name: "unitxt.card2"}, + Template: &lmesv1alpha1.Template{ + Ref: "tp_1", + }, + SystemPrompt: &lmesv1alpha1.SystemPrompt{ + Ref: "sp_1", + }, + Format: &format, + Metrics: []string{"unitxt.metric6", "unitxt.metric7"}, + NumDemos: &numDemos, + DemosPoolSize: &demosPoolSize, + }, + ) + + job.Spec.TaskList.CustomArtifacts.Templates = append(job.Spec.TaskList.CustomArtifacts.Templates, lmesv1alpha1.CustomArtifact{ + Name: "tp_1", + Value: `{ "__type__": "input_output_template", "instruction": "2In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + }) + + job.Spec.TaskList.CustomArtifacts.SystemPrompts = append(job.Spec.TaskList.CustomArtifacts.SystemPrompts, lmesv1alpha1.CustomArtifact{ + Name: "sp_1", + Value: "this is a custom system promp2", + }) + + assert.Equal(t, []string{ + "/opt/app-root/src/bin/driver", + "--output-path", "/opt/app-root/src/output", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--task-recipe", "card=unitxt.card,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric4,unitxt.metric5],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--task-recipe", "card=cards.custom_2,template=templates.tp_0,system_prompt=system_prompts.sp_0,metrics=[unitxt.metric6,unitxt.metric7],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--task-recipe", "card=unitxt.card2,template=templates.tp_1,system_prompt=system_prompts.sp_1,metrics=[unitxt.metric6,unitxt.metric7],format=unitxt.format,num_demos=5,demos_pool_size=10", + "--custom-template", `tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + "--custom-template", `tp_1|{ "__type__": "input_output_template", "instruction": "2In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, + "--custom-prompt", "sp_0|this is a custom system promp", + "--custom-prompt", "sp_1|this is a custom system promp2", "--", }, generateCmd(svcOpts, job)) } @@ -1048,7 +1199,7 @@ func Test_CustomCardValidation(t *testing.T) { }, } - assert.ErrorContains(t, lmevalRec.validateCustomCard(job, log), "custom card is not a valid JSON string") + assert.ErrorContains(t, lmevalRec.validateCustomRecipes(job, log), "custom card is not a valid JSON string") // no loader job.Spec.TaskList.TaskRecipes[0].Card.Custom = ` @@ -1076,7 +1227,7 @@ func Test_CustomCardValidation(t *testing.T) { "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }` - assert.ErrorContains(t, lmevalRec.validateCustomCard(job, log), "no loader definition in the custom card") + assert.ErrorContains(t, lmevalRec.validateCustomRecipes(job, log), "no loader definition in the custom card") // ok job.Spec.TaskList.TaskRecipes[0].Card.Custom = ` @@ -1110,14 +1261,66 @@ func Test_CustomCardValidation(t *testing.T) { "templates": "templates.translation.directed.all" }` - assert.Nil(t, lmevalRec.validateCustomCard(job, log)) + assert.Nil(t, lmevalRec.validateCustomRecipes(job, log)) + + job.Spec.TaskList.TaskRecipes[0].Template = &lmesv1alpha1.Template{ + Ref: "tp_0", + } + + // missing custom template + assert.ErrorContains(t, lmevalRec.validateCustomRecipes(job, log), "the reference name of the custom template is not defined: tp_0") + + job.Spec.TaskList.CustomArtifacts = &lmesv1alpha1.CustomArtifacts{ + Templates: []lmesv1alpha1.CustomArtifact{ + { + Name: "tp_0", + Value: ` + { + "__type__": "input_output_template", + "instruction": "In the following task, you translate a {text_type}.", + "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", + "target_prefix": "Translation: ", + "output_format": "{translation}", + "postprocessors": [ + "processors.lower_case" + ] + } + `, + }, + }, + } + + // pass + assert.Nil(t, lmevalRec.validateCustomRecipes(job, log)) + + job.Spec.TaskList.CustomArtifacts.Templates = append(job.Spec.TaskList.CustomArtifacts.Templates, lmesv1alpha1.CustomArtifact{ + Name: "tp_1", + Value: ` + { + "__type__": "input_output_template", + "instruction": "In the following task, you translate a {text_type}.", + "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", + "target_prefix": "Translation: ", + "postprocessors": [ + "processors.lower_case" + ] + } + `, + }) + + job.Spec.TaskList.TaskRecipes[0].Template = &lmesv1alpha1.Template{ + Ref: "tp_1", + } + + // missing outout_format property + assert.ErrorContains(t, lmevalRec.validateCustomRecipes(job, log), "no output_format definition in the custom template") } func Test_ConcatTasks(t *testing.T) { tasks := concatTasks(lmesv1alpha1.TaskList{ TaskNames: []string{"task1", "task2"}, TaskRecipes: []lmesv1alpha1.TaskRecipe{ - {Template: "template3", Card: lmesv1alpha1.Card{Name: "format3"}}, + {Template: &lmesv1alpha1.Template{Name: "template3"}, Card: lmesv1alpha1.Card{Name: "format3"}}, }, })