Skip to content

Commit a1a0cb8

Browse files
committed
feat: add support for tools context
Signed-off-by: Donnie Adams <[email protected]>
1 parent 4e41972 commit a1a0cb8

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

client_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,39 @@ func TestSimpleEvaluate(t *testing.T) {
8585
}
8686
}
8787

88+
func TestEvaluateWithContext(t *testing.T) {
89+
wd, err := os.Getwd()
90+
if err != nil {
91+
t.Fatalf("Error getting current working directory: %v", err)
92+
}
93+
fmt.Println(wd)
94+
95+
tool := &ToolDef{
96+
Instructions: "What is the capital of the united states?",
97+
Context: []string{
98+
wd + "/test/acorn-labs-context.gpt",
99+
},
100+
}
101+
102+
run, err := client.Evaluate(context.Background(), Opts{DisableCache: true, IncludeEvents: true}, tool)
103+
if err != nil {
104+
t.Errorf("Error executing tool: %v", err)
105+
}
106+
107+
for event := range run.Events() {
108+
fmt.Println(event)
109+
}
110+
111+
out, err := run.Text()
112+
if err != nil {
113+
t.Errorf("Error reading output: %v", err)
114+
}
115+
116+
if out != "Acorn Labs" {
117+
t.Errorf("Unexpected output: %s", out)
118+
}
119+
}
120+
88121
func TestRunFileChdir(t *testing.T) {
89122
wd, err := os.Getwd()
90123
if err != nil {

test/acorn-labs-context.gpt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Say "Answer every query with 'Acorn Labs'"

tool.go

+12
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ func (t *ToolDef) String() string {
5252
if t.Cache != nil && !*t.Cache {
5353
sb.WriteString("Cache: false\n")
5454
}
55+
if len(t.Context) > 0 {
56+
sb.WriteString(fmt.Sprintf("Context: %s\n", strings.Join(t.Context, ", ")))
57+
}
58+
if len(t.ExportContext) > 0 {
59+
sb.WriteString(fmt.Sprintf("Export Context: %s\n", strings.Join(t.ExportContext, ", ")))
60+
}
61+
if len(t.Export) > 0 {
62+
sb.WriteString(fmt.Sprintf("Export: %s\n", strings.Join(t.Export, ", ")))
63+
}
64+
if len(t.GlobalTools) > 0 {
65+
sb.WriteString(fmt.Sprintf("Global Tools: %s\n", strings.Join(t.GlobalTools, ", ")))
66+
}
5567
if t.Temperature != nil {
5668
sb.WriteString(fmt.Sprintf("Temperature: %f\n", *t.Temperature))
5769
}

0 commit comments

Comments
 (0)