Skip to content

Commit

Permalink
feat(LMES): Support custom template and prompt (trustyai-explainabili…
Browse files Browse the repository at this point in the history
…ty#404)

Expand the LMEvalJob CRD to support custom templates and system prompts.
This is mainly for custom unitxt task recipes. Now, users can use the
`template` and `systemPrompt` fields under the `taskRecipes` to specify
the custom template and system prompt.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang authored Feb 19, 2025
1 parent bf9d8fd commit 6551b6a
Show file tree
Hide file tree
Showing 9 changed files with 748 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.lmes-job
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ RUN curl -L https://github.com/opendatahub-io/lm-evaluation-harness/archive/refs

RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt

ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src:/opt/app-root/src/server
ENV HF_HOME=/opt/app-root/src/hf_home
ENV UNITXT_CATALOGS=/opt/app-root/src/my_catalogs

Expand Down
79 changes: 77 additions & 2 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,69 @@ type Card struct {
Custom string `json:"custom,omitempty"`
}

type Template struct {
// Unitxt template ID
// +optional
Name string `json:"name,omitempty"`
// The name of the custom template in the custom field. Its value is a JSON string
// for a custom Unitxt template. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
// to compose a custom template, store it as a JSON file by calling the
// add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
// and use the JSON content as the value here.
// +optional
Ref string `json:"ref,omitempty"`
}

type SystemPrompt struct {
// Unitxt System Prompt id
Name string `json:"name,omitempty"`
// The name of the custom systemPrompt in the custom field. Its value is a custom system prompt string
Ref string `json:"ref,omitempty"`
}

type CustomArtifact struct {
// Name of the custom artifact
Name string `json:"name"`
// Value of the custom artifact. It could be a JSON string or plain text
// depending on the artifact type
Value string `json:"value"`
}

func (c *CustomArtifact) String() string {
return fmt.Sprintf("%s|%s", c.Name, c.Value)
}

type CustomArtifacts struct {
Templates []CustomArtifact `json:"templates,omitempty"`
SystemPrompts []CustomArtifact `json:"systemPrompts,omitempty"`
}

func (c *CustomArtifacts) GetTemplates() []CustomArtifact {
if c == nil {
return nil
}
return c.Templates
}

func (c *CustomArtifacts) GetSystemPrompts() []CustomArtifact {
if c == nil {
return nil
}
return c.SystemPrompts
}

// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
// Find details of the Unitxt Recipe here:
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
type TaskRecipe struct {
// The Unitxt dataset card
Card Card `json:"card"`
// The Unitxt template
Template string `json:"template"`
// +optional
Template *Template `json:"template,omitempty"`
// The Unitxt System Prompt
// +optional
SystemPrompt *SystemPrompt `json:"systemPrompt,omitempty"`
// The Unitxt Task
// +optional
Task *string `json:"task,omitempty"`
Expand All @@ -109,11 +164,31 @@ type TaskList struct {
TaskNames []string `json:"taskNames,omitempty"`
// Task Recipes specifically for Unitxt
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
// Custom Unitxt artifacts that can be used in a TaskRecipe
CustomArtifacts *CustomArtifacts `json:"custom,omitempty"`
}

// Use the tp_idx and sp_idx to point to the corresponding custom template
// and custom system_prompt
func (t *TaskRecipe) String() string {
var b strings.Builder
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template))
b.WriteString(fmt.Sprintf("card=%s", t.Card.Name))
if t.Template != nil {
if t.Template.Name != "" {
b.WriteString(fmt.Sprintf(",template=%s", t.Template.Name))
} else {
// refer to a custom template. add "templates." prefix
b.WriteString(fmt.Sprintf(",template=templates.%s", t.Template.Ref))
}
}
if t.SystemPrompt != nil {
if t.SystemPrompt.Name != "" {
b.WriteString(fmt.Sprintf(",system_prompt=%s", t.SystemPrompt.Name))
} else {
// refer to custom system prompt. add "system_prompts." prefix
b.WriteString(fmt.Sprintf(",system_prompt=system_prompts.%s", t.SystemPrompt.Ref))
}
}
if t.Task != nil {
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
}
Expand Down
85 changes: 85 additions & 0 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 25 additions & 19 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,25 @@ func (t *strArrayArg) String() string {
}

var (
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
driverLog = ctrl.Log.WithName("driver")
taskRecipes strArrayArg
customCards strArrayArg
customTemplates strArrayArg
customSystemPrompts strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
driverLog = ctrl.Log.WithName("driver")
)

func init() {
flag.Var(&taskRecipes, "task-recipe", "task recipe")
flag.Var(&customCards, "custom-card", "A JSON string represents a custom card")
flag.Var(&customTemplates, "custom-template", "A JSON string represents a custom template")
flag.Var(&customSystemPrompts, "custom-prompt", "A string represents a custom system_prompt")
}

func main() {
Expand Down Expand Up @@ -107,15 +111,17 @@ func main() {
}

driverOpt := driver.DriverOption{
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
CommPort: *commPort,
DownloadAssetsS3: *downloadAssetsS3,
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
CustomTemplates: customTemplates,
CustomSystemPrompt: customSystemPrompts,
Args: args,
CommPort: *commPort,
DownloadAssetsS3: *downloadAssetsS3,
}

driver, err := driver.NewDriver(&driverOpt)
Expand Down
63 changes: 61 additions & 2 deletions config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4721,6 +4721,42 @@ spec:
taskList:
description: Evaluation task list
properties:
custom:
description: Custom Unitxt artifacts that can be used in a TaskRecipe
properties:
systemPrompts:
items:
properties:
name:
description: Name of the custom artifact
type: string
value:
description: |-
Value of the custom artifact. It could be a JSON string or plain text
depending on the artifact type
type: string
required:
- name
- value
type: object
type: array
templates:
items:
properties:
name:
description: Name of the custom artifact
type: string
value:
description: |-
Value of the custom artifact. It could be a JSON string or plain text
depending on the artifact type
type: string
required:
- name
- value
type: object
type: array
type: object
taskNames:
description: TaskNames from lm-eval's task list
items:
Expand Down Expand Up @@ -4764,15 +4800,38 @@ spec:
numDemos:
description: Number of fewshot
type: integer
systemPrompt:
description: The Unitxt System Prompt
properties:
name:
description: Unitxt System Prompt id
type: string
ref:
description: The name of the custom systemPrompt in
the custom field. Its value is a custom system prompt
string
type: string
type: object
task:
description: The Unitxt Task
type: string
template:
description: The Unitxt template
type: string
properties:
name:
description: Unitxt template ID
type: string
ref:
description: |-
The name of the custom template in the custom field. Its value is a JSON string
for a custom Unitxt template. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
to compose a custom template, store it as a JSON file by calling the
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
and use the JSON content as the value here.
type: string
type: object
required:
- card
- template
type: object
type: array
type: object
Expand Down
Loading

0 comments on commit 6551b6a

Please sign in to comment.