diff --git a/integration/cred_test.go b/integration/cred_test.go index 1ea73d35..a32e50e9 100644 --- a/integration/cred_test.go +++ b/integration/cred_test.go @@ -14,22 +14,6 @@ func TestGPTScriptCredential(t *testing.T) { require.Contains(t, out, "CREDENTIAL") } -// TestCredentialScopes makes sure that environment variables set by credential tools and shared credential tools -// are only available to the correct tools. See scripts/credscopes.gpt for more details. -func TestCredentialScopes(t *testing.T) { - out, err := RunScript("scripts/cred_scopes.gpt", "--sub-tool", "oneOne") - require.NoError(t, err) - require.Contains(t, out, "good") - - out, err = RunScript("scripts/cred_scopes.gpt", "--sub-tool", "twoOne") - require.NoError(t, err) - require.Contains(t, out, "good") - - out, err = RunScript("scripts/cred_scopes.gpt", "--sub-tool", "twoTwo") - require.NoError(t, err) - require.Contains(t, out, "good") -} - // TestCredentialExpirationEnv tests a GPTScript with two credentials that expire at different times. // One expires after two hours, and the other expires after one hour. // This test makes sure that the GPTSCRIPT_CREDENTIAL_EXPIRATION environment variable is set to the nearer expiration time (1h). diff --git a/integration/scripts/cred_scopes.gpt b/integration/scripts/cred_scopes.gpt deleted file mode 100644 index dc8e24e7..00000000 --- a/integration/scripts/cred_scopes.gpt +++ /dev/null @@ -1,160 +0,0 @@ -# This script sets up a chain of tools in a tree structure. -# The root is oneOne, with children twoOne and twoTwo, with children threeOne, threeTwo, and threeThree, with only -# threeTwo shared between them. -# Each tool should only have access to any credentials it defines and any credentials exported/shared by its -# immediate children (but not grandchildren). -# This script checks to make sure that this is working properly. -name: oneOne -tools: twoOne, twoTwo -cred: getcred with oneOne as var and 11 as val - -#!python3 - -import os - -oneOne = os.getenv('oneOne') -twoOne = os.getenv('twoOne') -twoTwo = os.getenv('twoTwo') -threeOne = os.getenv('threeOne') -threeTwo = os.getenv('threeTwo') -threeThree = os.getenv('threeThree') - -if oneOne != '11': - print('error: oneOne is not 11') - exit(1) - -if twoOne != '21': - print('error: twoOne is not 21') - exit(1) - -if twoTwo != '22': - print('error: twoTwo is not 22') - exit(1) - -if threeOne is not None: - print('error: threeOne is not None') - exit(1) - -if threeTwo is not None: - print('error: threeTwo is not None') - exit(1) - -if threeThree is not None: - print('error: threeThree is not None') - exit(1) - -print('good') - ---- -name: twoOne -tools: threeOne, threeTwo -sharecred: getcred with twoOne as var and 21 as val - -#!python3 - -import os - -oneOne = os.getenv('oneOne') -twoOne = os.getenv('twoOne') -twoTwo = os.getenv('twoTwo') -threeOne = os.getenv('threeOne') -threeTwo = os.getenv('threeTwo') -threeThree = os.getenv('threeThree') - -if oneOne is not None: - print('error: oneOne is not None') - exit(1) - -if twoOne is not None: - print('error: twoOne is not None') - exit(1) - -if twoTwo is not None: - print('error: twoTwo is not None') - exit(1) - -if threeOne != '31': - print('error: threeOne is not 31') - exit(1) - -if threeTwo != '32': - print('error: threeTwo is not 32') - exit(1) - -if threeThree is not None: - print('error: threeThree is not None') - exit(1) - -print('good') - ---- -name: twoTwo -tools: threeTwo, threeThree -sharecred: getcred with twoTwo as var and 22 as val - -#!python3 - -import os - -oneOne = os.getenv('oneOne') -twoOne = os.getenv('twoOne') -twoTwo = os.getenv('twoTwo') -threeOne = os.getenv('threeOne') -threeTwo = os.getenv('threeTwo') -threeThree = os.getenv('threeThree') - -if oneOne is not None: - print('error: oneOne is not None') - exit(1) - -if twoOne is not None: - print('error: twoOne is not None') - exit(1) - -if twoTwo is not None: - print('error: twoTwo is not None') - exit(1) - -if threeOne is not None: - print('error: threeOne is not None') - exit(1) - -if threeTwo != '32': - print('error: threeTwo is not 32') - exit(1) - -if threeThree != '33': - print('error: threeThree is not 33') - exit(1) - -print('good') - ---- -name: threeOne -sharecred: getcred with threeOne as var and 31 as val - ---- -name: threeTwo -sharecred: getcred with threeTwo as var and 32 as val - ---- -name: threeThree -sharecred: getcred with threeThree as var and 33 as val - ---- -name: getcred - -#!python3 - -import os -import json - -var = os.getenv('VAR') -val = os.getenv('VAL') - -output = { - "env": { - var: val - } -} -print(json.dumps(output)) diff --git a/pkg/config/cliconfig.go b/pkg/config/cliconfig.go index 7a82b58a..d0ef00c8 100644 --- a/pkg/config/cliconfig.go +++ b/pkg/config/cliconfig.go @@ -73,7 +73,6 @@ func (a *AuthConfig) UnmarshalJSON(data []byte) error { type CLIConfig struct { Auths map[string]AuthConfig `json:"auths,omitempty"` CredentialsStore string `json:"credsStore,omitempty"` - GatewayURL string `json:"gatewayURL,omitempty"` Integrations map[string]string `json:"integrations,omitempty"` auths map[string]types.AuthConfig diff --git a/pkg/credentials/store.go b/pkg/credentials/store.go index 9827b147..1843cd8d 100644 --- a/pkg/credentials/store.go +++ b/pkg/credentials/store.go @@ -225,7 +225,7 @@ func validateCredentialCtx(ctxs []string) error { } // check alphanumeric - r := regexp.MustCompile("^[a-zA-Z0-9]+$") + r := regexp.MustCompile("^[-a-zA-Z0-9]+$") for _, c := range ctxs { if !r.MatchString(c) { return fmt.Errorf("credential contexts must be alphanumeric") diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 956822dd..e6113d86 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -1,7 +1,6 @@ package parser import ( - "bufio" "fmt" "io" "maps" @@ -17,8 +16,10 @@ import ( var ( sepRegex = regexp.MustCompile(`^\s*---+\s*$`) + endHeaderRegex = regexp.MustCompile(`^\s*===+\s*$`) strictSepRegex = regexp.MustCompile(`^---\n$`) skipRegex = regexp.MustCompile(`^![-.:*\w]+\s*$`) + nameRegex = regexp.MustCompile(`^[a-z]+$`) ) func normalize(key string) string { @@ -74,7 +75,7 @@ func addArg(line string, tool *types.Tool) error { return nil } -func isParam(line string, tool *types.Tool) (_ bool, err error) { +func isParam(line string, tool *types.Tool, scan *simplescanner) (_ bool, err error) { key, value, ok := strings.Cut(line, ":") if !ok { return false, nil @@ -90,7 +91,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { case "globalmodel", "globalmodelname": tool.Parameters.GlobalModelName = value case "description": - tool.Parameters.Description = value + tool.Parameters.Description = scan.AddMultiline(value) case "internalprompt": v, err := toBool(value) if err != nil { @@ -104,27 +105,33 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { } tool.Parameters.Chat = v case "export", "exporttool", "exports", "exporttools", "sharetool", "sharetools", "sharedtool", "sharedtools": - tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...) + tool.Parameters.Export = append(tool.Parameters.Export, csv(scan.AddMultiline(value))...) case "tool", "tools": - tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...) + tool.Parameters.Tools = append(tool.Parameters.Tools, csv(scan.AddMultiline(value))...) case "inputfilter", "inputfilters": - tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...) + tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(scan.AddMultiline(value))...) case "shareinputfilter", "shareinputfilters", "sharedinputfilter", "sharedinputfilters": - tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...) + tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(scan.AddMultiline(value))...) case "outputfilter", "outputfilters": - tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(value)...) + tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(scan.AddMultiline(value))...) case "shareoutputfilter", "shareoutputfilters", "sharedoutputfilter", "sharedoutputfilters": - tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(value)...) + tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(scan.AddMultiline(value))...) case "agent", "agents": - tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...) + tool.Parameters.Agents = append(tool.Parameters.Agents, csv(scan.AddMultiline(value))...) case "globaltool", "globaltools": - tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(value)...) + tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(scan.AddMultiline(value))...) case "exportcontext", "exportcontexts", "sharecontext", "sharecontexts", "sharedcontext", "sharedcontexts": - tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(value)...) + tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(scan.AddMultiline(value))...) case "context": - tool.Parameters.Context = append(tool.Parameters.Context, csv(value)...) + tool.Parameters.Context = append(tool.Parameters.Context, csv(scan.AddMultiline(value))...) + case "metadata": + mkey, mvalue, _ := strings.Cut(scan.AddMultiline(value), ":") + if tool.MetaData == nil { + tool.MetaData = map[string]string{} + } + tool.MetaData[strings.TrimSpace(mkey)] = strings.TrimSpace(mvalue) case "args", "arg", "param", "params", "parameters", "parameter": - if err := addArg(value, tool); err != nil { + if err := addArg(scan.AddMultiline(value), tool); err != nil { return false, err } case "maxtoken", "maxtokens": @@ -149,13 +156,13 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { return false, err } case "credentials", "creds", "credential", "cred": - tool.Parameters.Credentials = append(tool.Parameters.Credentials, value) + tool.Parameters.Credentials = append(tool.Parameters.Credentials, csv(scan.AddMultiline(value))...) case "sharecredentials", "sharecreds", "sharecredential", "sharecred", "sharedcredentials", "sharedcreds", "sharedcredential", "sharedcred": - tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value) + tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, scan.AddMultiline(value)) case "type": tool.Type = types.ToolType(strings.ToLower(value)) default: - return false, nil + return nameRegex.MatchString(key), nil } return true, nil @@ -206,6 +213,7 @@ func (c *context) finish(tools *[]Node) { len(c.tool.ExportInputFilters) > 0 || len(c.tool.ExportOutputFilters) > 0 || len(c.tool.Agents) > 0 || + len(c.tool.ExportCredentials) > 0 || c.tool.Chat { *tools = append(*tools, Node{ ToolNode: &ToolNode{ @@ -391,7 +399,10 @@ func assignMetadata(nodes []Node) (result []Node) { for _, node := range nodes { if node.ToolNode != nil { - node.ToolNode.Tool.MetaData = metadata[node.ToolNode.Tool.Name] + if node.ToolNode.Tool.MetaData == nil { + node.ToolNode.Tool.MetaData = map[string]string{} + } + maps.Copy(node.ToolNode.Tool.MetaData, metadata[node.ToolNode.Tool.Name]) for wildcard := range metadata { if strings.Contains(wildcard, "*") { if m, err := path.Match(wildcard, node.ToolNode.Tool.Name); m && err == nil { @@ -433,15 +444,71 @@ func isGPTScriptHashBang(line string) bool { return false } -func parse(input io.Reader) ([]Node, error) { - scan := bufio.NewScanner(input) +type simplescanner struct { + lines []string +} + +func newSimpleScanner(data []byte) *simplescanner { + if len(data) == 0 { + return &simplescanner{} + } + lines := strings.Split(string(data), "\n") + return &simplescanner{ + lines: append([]string{""}, lines...), + } +} + +func dropCR(s string) string { + if len(s) > 0 && s[len(s)-1] == '\r' { + return s[:len(s)-1] + } + return s +} +func (s *simplescanner) AddMultiline(current string) string { + result := current + for { + if len(s.lines) < 2 || len(s.lines[1]) == 0 { + return result + } + if strings.HasPrefix(s.lines[1], " ") || strings.HasPrefix(s.lines[1], "\t") { + result += " " + dropCR(s.lines[1]) + s.lines = s.lines[1:] + } else { + return result + } + } +} + +func (s *simplescanner) Text() string { + if len(s.lines) == 0 { + return "" + } + return dropCR(s.lines[0]) +} + +func (s *simplescanner) Scan() bool { + if len(s.lines) == 0 { + return false + } + s.lines = s.lines[1:] + return true +} + +func parse(input io.Reader) ([]Node, error) { var ( tools []Node context context lineNo int ) + data, err := io.ReadAll(input) + if err != nil { + return nil, err + } + + scan := newSimpleScanner(data) + for scan.Scan() { lineNo++ if context.tool.Source.LineNo == 0 { @@ -488,11 +555,15 @@ func parse(input io.Reader) ([]Node, error) { } // Look for params - if isParam, err := isParam(line, &context.tool); err != nil { + if isParam, err := isParam(line, &context.tool, scan); err != nil { return nil, NewErrLine("", lineNo, err) } else if isParam { context.seenParam = true continue + } else if endHeaderRegex.MatchString(line) { + // force the end of the header and don't include the current line in the header + context.inBody = true + continue } } diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index f98b74e2..6eab45c9 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -1,6 +1,7 @@ package parser import ( + "reflect" "strings" "testing" @@ -244,6 +245,7 @@ share output filters: shared func TestParseMetaData(t *testing.T) { input := ` name: first +metadata: foo: bar body --- @@ -269,8 +271,89 @@ foo bar assert.Len(t, tools, 1) autogold.Expect(map[string]string{ + "foo": "bar", "package.json": "foo=base\nf", "requirements.txt": "asdf", "other": "foo bar", }).Equal(t, tools[0].MetaData) + + autogold.Expect(`Name: first +Meta Data: foo: bar +Meta Data: other: foo bar +Meta Data: requirements.txt: asdf + +body +--- +!metadata:first:package.json +foo=base +f +`).Equal(t, tools[0].String()) +} + +func TestFormatWithBadInstruction(t *testing.T) { + input := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "foo", + }, + Instructions: "foo: bar", + }, + } + autogold.Expect("Name: foo\n===\nfoo: bar\n").Equal(t, input.String()) + + tools, err := ParseTools(strings.NewReader(input.String())) + require.NoError(t, err) + if reflect.DeepEqual(input, tools[0]) { + t.Errorf("expected %v, got %v", input, tools[0]) + } +} + +func TestSingleTool(t *testing.T) { + input := ` +name: foo + +#!sys.echo +hi +` + + tools, err := ParseTools(strings.NewReader(input)) + require.NoError(t, err) + autogold.Expect(types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{Name: "foo"}, + Instructions: "#!sys.echo\nhi", + }, + Source: types.ToolSource{LineNo: 1}, + }).Equal(t, tools[0]) +} + +func TestMultiline(t *testing.T) { + input := ` +name: first +credential: foo + , bar, + baz +model: the model + +body +` + tools, err := ParseTools(strings.NewReader(input)) + require.NoError(t, err) + + assert.Len(t, tools, 1) + autogold.Expect(types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "first", + ModelName: "the model", + Credentials: []string{ + "foo", + "bar", + "baz", + }, + }, + Instructions: "body", + }, + Source: types.ToolSource{LineNo: 1}, + }).Equal(t, tools[0]) } diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index 93899c84..8dbd2ba7 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -79,3 +79,35 @@ echo ${FOO}:${INPUT} resp, err = r.Chat(context.Background(), nil, prg, nil, `"foo":"123"}`) r.AssertStep(t, resp, err) } + +func TestShareCreds(t *testing.T) { + r := tester.NewRunner(t) + prg, err := loader.ProgramFromSource(context.Background(), ` +creds: foo + +#!/bin/bash +echo $CRED +echo $CRED2 + +--- +name: foo +share credentials: bar + +--- +name: bar +share credentials: baz + +#!/bin/bash +echo '{"env": {"CRED": "that worked"}}' + +--- +name: baz + +#!/bin/bash +echo '{"env": {"CRED2": "that also worked"}}' +`, "") + require.NoError(t, err) + + resp, err := r.Chat(context.Background(), nil, prg, nil, "") + r.AssertStep(t, resp, err) +} diff --git a/pkg/tests/testdata/TestShareCreds/step1.golden b/pkg/tests/testdata/TestShareCreds/step1.golden new file mode 100644 index 00000000..9d584f92 --- /dev/null +++ b/pkg/tests/testdata/TestShareCreds/step1.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "that worked\nthat also worked\n", + "toolID": "", + "state": null +}` diff --git a/pkg/types/completion.go b/pkg/types/completion.go index 5b3899c3..6a05effa 100644 --- a/pkg/types/completion.go +++ b/pkg/types/completion.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - "github.com/fatih/color" "github.com/getkin/kin-openapi/openapi3" ) @@ -112,7 +111,7 @@ func (c CompletionMessage) String() string { } buf.WriteString(content.Text) if content.ToolCall != nil { - buf.WriteString(fmt.Sprintf(" %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments)) + buf.WriteString(fmt.Sprintf(" %s -> %s", content.ToolCall.Function.Name, content.ToolCall.Function.Arguments)) } } return buf.String() diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 0bd7bc02..cefbd311 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -476,9 +476,22 @@ func (t ToolDef) String() string { _, _ = fmt.Fprintf(buf, "Chat: true\n") } + keys := maps.Keys(t.MetaData) + sort.Strings(keys) + for _, key := range keys { + value := t.MetaData[key] + if !strings.Contains(value, "\n") { + _, _ = fmt.Fprintf(buf, "Meta Data: %s: %s\n", key, value) + } + } + // Instructions should be printed last if t.Instructions != "" && t.BuiltinFunc == nil { - _, _ = fmt.Fprintln(buf) + if strings.Contains(strings.Split(strings.TrimSpace(t.Instructions), "\n")[0], ":") { + _, _ = fmt.Fprintln(buf, "===") + } else { + _, _ = fmt.Fprintln(buf) + } _, _ = fmt.Fprintln(buf, t.Instructions) } @@ -486,14 +499,17 @@ func (t ToolDef) String() string { keys := maps.Keys(t.MetaData) sort.Strings(keys) for _, key := range keys { - buf.WriteString("---\n") - buf.WriteString("!metadata:") - buf.WriteString(t.Name) - buf.WriteString(":") - buf.WriteString(key) - buf.WriteString("\n") - buf.WriteString(t.MetaData[key]) - buf.WriteString("\n") + value := t.MetaData[key] + if strings.Contains(value, "\n") { + buf.WriteString("---\n") + buf.WriteString("!metadata:") + buf.WriteString(t.Name) + buf.WriteString(":") + buf.WriteString(key) + buf.WriteString("\n") + buf.WriteString(t.MetaData[key]) + buf.WriteString("\n") + } } } @@ -512,6 +528,56 @@ func (t Tool) GetNextAgentGroup(prg *Program, agentGroup []ToolReference, toolID return agentGroup, nil } +func (t Tool) getCredentials(prg *Program) (result []ToolReference, _ error) { + toolRefs, err := t.GetToolRefsFromNames(t.Credentials) + if err != nil { + return nil, err + } + + for _, toolRef := range toolRefs { + tool, ok := prg.ToolSet[toolRef.ToolID] + if !ok { + continue + } + + if !tool.IsNoop() { + result = append(result, toolRef) + } + + shared, err := tool.getSharedCredentials(prg) + if err != nil { + return nil, err + } + result = append(result, shared...) + } + + return result, nil +} + +func (t Tool) getSharedCredentials(prg *Program) (result []ToolReference, _ error) { + toolRefs, err := t.GetToolRefsFromNames(t.ExportCredentials) + if err != nil { + return nil, err + } + for _, toolRef := range toolRefs { + tool, ok := prg.ToolSet[toolRef.ToolID] + if !ok { + continue + } + + if !tool.IsNoop() { + result = append(result, toolRef) + } + + nested, err := tool.getSharedCredentials(prg) + if err != nil { + return nil, err + } + result = append(result, nested...) + } + return result, nil +} + func (t Tool) getAgents(prg *Program) (result []ToolReference, _ error) { toolRefs, err := t.GetToolRefsFromNames(t.Agents) if err != nil { @@ -542,6 +608,9 @@ func (t Tool) GetToolsByType(prg *Program, toolType ToolType) ([]ToolReference, if toolType == ToolTypeAgent { // Agents are special, they can only be sourced from direct references and not the generic 'tool:' or shared by references return t.getAgents(prg) + } else if toolType == ToolTypeCredential { + // Credentials are special too, you can only get shared credentials from directly referenced credentials + return t.getCredentials(prg) } toolSet := &toolRefSet{} @@ -560,8 +629,6 @@ func (t Tool) GetToolsByType(prg *Program, toolType ToolType) ([]ToolReference, directRefs = t.InputFilters case ToolTypeTool: toolsListFilterType = append(toolsListFilterType, ToolTypeDefault, ToolTypeAgent) - case ToolTypeCredential: - directRefs = t.Credentials default: return nil, fmt.Errorf("unknown tool type %v", toolType) } @@ -602,8 +669,6 @@ func (t Tool) GetToolsByType(prg *Program, toolType ToolType) ([]ToolReference, case ToolTypeInput: exportRefs = tool.ExportInputFilters case ToolTypeTool: - case ToolTypeCredential: - exportRefs = tool.ExportCredentials default: return nil, fmt.Errorf("unknown tool type %v", toolType) } diff --git a/pkg/types/tool_test.go b/pkg/types/tool_test.go index e95c2248..89c36ac8 100644 --- a/pkg/types/tool_test.go +++ b/pkg/types/tool_test.go @@ -73,6 +73,7 @@ Credential: Credential2 Share Credential: ExportCredential1 Share Credential: ExportCredential2 Chat: true +Meta Data: requirements.txt: requests=5 This is a sample instruction --- @@ -81,9 +82,6 @@ This is a sample instruction // blah blah some ugly JSON } ---- -!metadata:Tool Sample:requirements.txt -requests=5 `).Equal(t, tool.String()) }