diff --git a/go.mod b/go.mod index f803a3b9..35c9689e 100644 --- a/go.mod +++ b/go.mod @@ -18,10 +18,11 @@ require ( github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 - github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee + github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 + github.com/mark3labs/mcp-go v0.21.1 github.com/mholt/archives v0.1.0 github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 @@ -122,6 +123,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.5.4 // indirect github.com/yuin/goldmark-emoji v1.0.2 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect diff --git a/go.sum b/go.sum index 74341af5..95e6b1a7 100644 --- a/go.sum +++ b/go.sum @@ -203,8 +203,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7J github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod 
h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 h1:wQC8sKyeGA50WnCEG+Jo5FNRIkuX3HX8d3ubyWCCoI8= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -270,6 +270,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.21.1 h1:7Ek6KPIIbMhEYHRiRIg6K6UAgNZCJaHKQp926MNr6V0= +github.com/mark3labs/mcp-go v0.21.1/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -406,6 +408,8 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.7/go.mod 
h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.5.4 h1:2uY/xC0roWy8IBEGLgB1ywIoEJFGmRrX21YQcvGZzjU= diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 4b0642d2..b5a823b2 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -215,7 +215,7 @@ func (r *GPTScript) listTools(ctx context.Context, gptScript *gptscript.GPTScrip // Don't print instructions tool.Instructions = "" - lines = append(lines, tool.String()) + lines = append(lines, tool.Print()) } fmt.Println(strings.Join(lines, "\n---\n")) return nil diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index abf45e8c..39357e9a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/counter" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) @@ -41,6 +42,11 @@ type Engine struct { RuntimeManager RuntimeManager Env []string Progress chan<- types.CompletionStatus + MCPRunner MCPRunner +} + +type MCPRunner interface { + Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) } type State struct { @@ -307,6 +313,21 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too return nil } +func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) { + runner := e.MCPRunner + if runner == nil { + runner = mcp.DefaultRunner + } + output, err := runner.Run(ctx.Ctx, e.Progress, tool, input) + if err != nil { + return nil, fmt.Errorf("failed to run MCP invoke: %w", err) + } + + return &Return{ + Result: &output, + }, nil +} + func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*Return, error) { if tool.IsHTTP() { return e.runHTTP(ctx, tool, input) @@ -342,6 +363,10 @@ func (e *Engine) 
Start(ctx Context, input string) (ret *Return, err error) { } }() + if tool.IsMCPInvoke() { + return e.runMCPInvoke(ctx, tool, input) + } + if tool.IsCommand() { return e.runCommandTools(ctx, tool, input) } @@ -378,6 +403,7 @@ func addUpdateSystem(ctx Context, tool types.Tool, msgs []types.CompletionMessag instructions = append(instructions, context.Content) } + tool.Instructions = strings.TrimPrefix(tool.Instructions, types.PromptPrefix) if tool.Instructions != "" { instructions = append(instructions, tool.Instructions) } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index e70827c6..2a6f2433 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -20,6 +20,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/system" @@ -155,7 +156,23 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { +func processMCP(ctx context.Context, tool []types.Tool, mcpLoader MCPLoader) (result []types.Tool, _ error) { + for _, t := range tool { + if t.IsMCP() { + mcpTools, err := mcpLoader.Load(ctx, t) + if err != nil { + return nil, fmt.Errorf("error loading MCP tools: %w", err) + } + result = append(result, mcpTools...) 
+ } else { + result = append(result, t) + } + } + + return result, nil +} + +func readTool(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { data := base.Content var ( @@ -212,6 +229,11 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, fmt.Errorf("no tools found in %s", base) } + tools, err := processMCP(ctx, tools, mcp) + if err != nil { + return nil, err + } + var ( localTools = types.ToolSet{} targetTools []types.Tool @@ -279,17 +301,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Name)] = tool } - return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel) + return linkAll(ctx, cache, mcp, prg, base, targetTools, localTools, defaultModel) } -func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { +func linkAll(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { localToolsMapping := make(map[string]string, len(tools)) for _, localTool := range localTools { localToolsMapping[strings.ToLower(localTool.Name)] = localTool.ID } for _, tool := range tools { - tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel) + tool, err := link(ctx, cache, mcp, prg, base, tool, localTools, localToolsMapping, defaultModel) if err != nil { return nil, err } @@ -298,7 +320,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base return } -func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel 
string) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -323,7 +345,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so linkedTool = existing } else { var err error - linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel) + linkedTool, err = link(ctx, cache, mcp, prg, base, localTool, localTools, localToolsMapping, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -333,7 +355,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel) + resolvedTools, err := resolve(ctx, cache, mcp, prg, base, toolName, subTool, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } @@ -373,7 +395,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. 
prg := types.Program{ ToolSet: types.ToolSet{}, } - tools, err := readTool(ctx, opt.Cache, &prg, &source{ + tools, err := readTool(ctx, opt.Cache, opt.MCPLoader, &prg, &source{ Content: []byte(content), Path: locationPath, Name: locationName, @@ -390,6 +412,11 @@ type Options struct { Cache *cache.Client Location string DefaultModel string + MCPLoader MCPLoader +} + +type MCPLoader interface { + Load(ctx context.Context, tool types.Tool) ([]types.Tool, error) } func complete(opts ...Options) (result Options) { @@ -397,6 +424,7 @@ func complete(opts ...Options) (result Options) { result.Cache = types.FirstSet(opt.Cache, result.Cache) result.Location = types.FirstSet(opt.Location, result.Location) result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) + result.MCPLoader = types.FirstSet(opt.MCPLoader, result.MCPLoader) } if result.Location == "" { @@ -407,6 +435,10 @@ func complete(opts ...Options) (result Options) { result.DefaultModel = builtin.GetDefaultModel() } + if result.MCPLoader == nil { + result.MCPLoader = mcp.DefaultLoader + } + return } @@ -430,7 +462,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel) + tools, err := resolve(ctx, opt.Cache, opt.MCPLoader, &prg, &source{}, name, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -438,7 +470,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { t, ok := builtin.DefaultModel(name, defaultModel) if ok { @@ -452,7 
+484,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base
 		return nil, err
 	}
 
-	result, err := readTool(ctx, cache, prg, s, subTool, defaultModel)
+	result, err := readTool(ctx, cache, mcp, prg, s, subTool, defaultModel)
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go
new file mode 100644
index 00000000..b4b33ba6
--- /dev/null
+++ b/pkg/mcp/loader.go
@@ -0,0 +1,297 @@
+package mcp
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"maps"
+	"slices"
+	"strings"
+	"sync"
+
+	"github.com/getkin/kin-openapi/openapi3"
+	"github.com/gptscript-ai/gptscript/pkg/hash"
+	"github.com/gptscript-ai/gptscript/pkg/types"
+	"github.com/gptscript-ai/gptscript/pkg/version"
+	"github.com/mark3labs/mcp-go/client"
+	"github.com/mark3labs/mcp-go/mcp"
+)
+
+var (
+	// DefaultLoader is the Local used as the fallback MCP loader.
+	DefaultLoader = &Local{}
+	// DefaultRunner is the same Local as DefaultLoader, so it sees the sessions that loading created.
+	DefaultRunner = DefaultLoader
+)
+
+// Local loads MCP server tools and runs tool calls against the sessions it has
+// opened. The sessions map is guarded by lock.
+type Local struct {
+	nextID   int64
+	lock     sync.Mutex
+	sessions map[string]*Session
+}
+
+// Session is an initialized client connection to one MCP server.
+type Session struct {
+	ID         string
+	InitResult *mcp.InitializeResult
+	Client     client.MCPClient
+	Config     ServerConfig
+}
+
+// Config is the multi-server form of an MCP configuration document.
+type Config struct {
+	MCPServers map[string]ServerConfig `json:"mcpServers"`
+}
+
+// ServerConfig describes how to reach one MCP server: a command to start when
+// Command is set, otherwise a URL to connect to.
+type ServerConfig struct {
+	DisableInstruction bool              `json:"disableInstruction"`
+	Command            string            `json:"command"`
+	Args               []string          `json:"args"`
+	Env                map[string]string `json:"env"`
+	Server             string            `json:"server"`
+	URL                string            `json:"url"`
+	BaseURL            string            `json:"baseURL,omitempty"`
+	Headers            map[string]string `json:"headers"`
+}
+
+// GetBaseURL returns the first non-empty value among BaseURL, Server and URL.
+func (s *ServerConfig) GetBaseURL() string {
+	if s.BaseURL != "" {
+		return s.BaseURL
+	}
+	if s.Server != "" {
+		return s.Server
+	}
+	return s.URL
+}
+
+// Load expands an MCP tool into the tools exposed by the server it configures.
+//
+// Everything after the first line of the tool's instructions is parsed as JSON,
+// either as a {"mcpServers": {...}} document or as a single server object.
+// Exactly one server is supported. Tools that are not MCP tools are returned
+// unchanged.
+func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, _ error) {
+	if !tool.IsMCP() {
+		return []types.Tool{tool}, nil
+	}
+
+	_, configData, _ := strings.Cut(tool.Instructions, "\n")
+	var servers Config
+
+	if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &servers); err != nil {
+		return nil, fmt.Errorf("failed to parse MCP configuration: %w\n%s", err, configData)
+	}
+
+	if len(servers.MCPServers) == 0 {
+		// Try to load just one server
+		var server ServerConfig
+		if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &server); err != nil {
+			return nil, fmt.Errorf("failed to parse single MCP server configuration: %w\n%s", err, configData)
+		}
+		if server.Command == "" && server.URL == "" && server.Server == "" {
+			return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData)
+		}
+		servers.MCPServers = map[string]ServerConfig{
+			"default": server,
+		}
+	}
+
+	if len(servers.MCPServers) > 1 {
+		return nil, fmt.Errorf("only a single MCP server definition is supported")
+	}
+
+	for _, server := range slices.Sorted(maps.Keys(servers.MCPServers)) {
+		session, err := l.loadSession(ctx, servers.MCPServers[server])
+		if err != nil {
+			return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err)
+		}
+
+		return l.sessionToTools(ctx, session, tool.Name)
+	}
+
+	// This should never happen, but just in case
+	return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData)
+}
+
+// sessionToTools lists the server's tools and converts them to gptscript tools.
+//
+// The first element of the result is a bundle tool named toolName that exports
+// every converted tool. When the server supplied instructions, a context tool
+// carrying them is appended and exported from the bundle as well.
+func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string) ([]types.Tool, error) {
+	tools, err := session.Client.ListTools(ctx, mcp.ListToolsRequest{})
+	if err != nil {
+		return nil, fmt.Errorf("failed to list tools: %w", err)
+	}
+
+	toolDefs := []types.Tool{{ /* this is a placeholder for main tool */ }}
+	var toolNames []string
+
+	for _, tool := range tools.Tools {
+		var schema openapi3.Schema
+
+		schemaData, err := json.Marshal(tool.InputSchema)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal tool input schema: %w", err)
+		}
+
+		if tool.Name == "" {
+			// I dunno, bad tool?
+			continue
+		}
+
+		if err := json.Unmarshal(schemaData, &schema); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal tool input schema: %w", err)
+		}
+
+		annotations, err := json.Marshal(tool.Annotations)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal tool annotations: %w", err)
+		}
+
+		toolDef := types.Tool{
+			ToolDef: types.ToolDef{
+				Parameters: types.Parameters{
+					Name:        tool.Name,
+					Description: tool.Description,
+					Arguments:   &schema,
+				},
+				Instructions: types.MCPInvokePrefix + "." + tool.Name + " " + session.ID + " " + tool.Name,
+			},
+		}
+
+		if string(annotations) != "{}" {
+			toolDef.MetaData = map[string]string{
+				"mcp-tool-annotations": string(annotations),
+			}
+		}
+
+		if tool.Annotations.Title != "" && !slices.Contains(strings.Fields(tool.Annotations.Title), "as") {
+			toolDef.Name = tool.Annotations.Title + " as " + tool.Name
+		}
+
+		toolDefs = append(toolDefs, toolDef)
+		toolNames = append(toolNames, tool.Name)
+	}
+
+	main := types.Tool{
+		ToolDef: types.ToolDef{
+			Parameters: types.Parameters{
+				Name:        toolName,
+				Description: session.InitResult.ServerInfo.Name,
+				Export:      toolNames,
+			},
+			MetaData: map[string]string{
+				"bundle": "true",
+			},
+		},
+	}
+
+	if session.InitResult.Instructions != "" {
+		data, err := json.Marshal(map[string]any{
+			"tools":        toolNames,
+			"instructions": session.InitResult.Instructions,
+		})
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal MCP server instructions: %w", err)
+		}
+		toolDefs = append(toolDefs, types.Tool{
+			ToolDef: types.ToolDef{
+				Parameters: types.Parameters{
+					Name: session.ID,
+					Type: "context",
+				},
+				Instructions: types.EchoPrefix + "\n" + `# START MCP SERVER INFO: ` + session.InitResult.ServerInfo.Name + "\n" +
+					`You have available the following tools from an MCP Server that has provided the following additional instructions` + "\n" +
+					string(data) + "\n" +
+					`# END MCP SERVER INFO` + "\n",
+			},
+		})
+
+		main.ExportContext = append(main.ExportContext, session.ID)
+	}
+
+	toolDefs[0] = main
+	return toolDefs, nil
+}
+
+// loadSession returns the session for server, creating and initializing one on
+// first use. Sessions are keyed by a digest of the server configuration, so
+// identical configurations share one session.
+func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, error) {
+	id := hash.Digest(server)
+	l.lock.Lock()
+	existing, ok := l.sessions[id]
+	l.lock.Unlock()
+	if ok {
+		return existing, nil
+	}
+
+	var (
+		c   client.MCPClient
+		err error
+	)
+
+	if server.Command != "" {
+		env := make([]string, 0, len(server.Env))
+		for k, v := range server.Env {
+			env = append(env, fmt.Sprintf("%s=%s", k, v))
+		}
+		c, err = client.NewStdioMCPClient(server.Command, env, server.Args...)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create MCP stdio client: %w", err)
+		}
+	} else {
+		url := server.URL
+		if url == "" {
+			url = server.Server
+		}
+		c, err = client.NewSSEMCPClient(url, client.WithHeaders(server.Headers))
+		if err != nil {
+			return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err)
+		}
+	}
+
+	var initRequest mcp.InitializeRequest
+	initRequest.Params.ClientInfo = mcp.Implementation{
+		Name:    version.ProgramName,
+		Version: version.Get().String(),
+	}
+
+	initResult, err := c.Initialize(ctx, initRequest)
+	if err != nil {
+		// Release the client so a failed handshake does not leak it. The
+		// initialize error is the one worth reporting, so a close error is dropped.
+		_ = c.Close()
+		return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
+	}
+
+	result := &Session{
+		ID:         id,
+		InitResult: initResult,
+		Client:     c,
+		Config:     server,
+	}
+
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	if existing, ok := l.sessions[id]; ok {
+		// Another caller registered this session while we were connecting.
+		// Discard our duplicate; a close error must not fail the load because
+		// the existing session is usable.
+		_ = c.Close()
+		return existing, nil
+	}
+
+	if l.sessions == nil {
+		l.sessions = make(map[string]*Session)
+	}
+	l.sessions[id] = result
+	return result, nil
+}
diff --git a/pkg/mcp/runner.go b/pkg/mcp/runner.go
new file mode 100644
index 00000000..b6d5f584
--- /dev/null
+++ b/pkg/mcp/runner.go
@@ -0,0 +1,57 @@
+package mcp
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/gptscript-ai/gptscript/pkg/types"
+	"github.com/mark3labs/mcp-go/mcp"
+)
+
+// Run invokes an MCP tool on a previously loaded session.
+//
+// The tool's instructions must contain at least three whitespace-separated
+// fields; the second is the session ID and the third the MCP tool name. A
+// non-empty input is decoded as a JSON object of arguments. The call result is
+// returned JSON-encoded.
+func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) {
+	fields := strings.Fields(tool.Instructions)
+	if len(fields) < 3 {
+		return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions)
+	}
+
+	id := fields[1]
+	toolName := fields[2]
+	arguments := map[string]any{}
+
+	if input != "" {
+		if err := json.Unmarshal([]byte(input), &arguments); err != nil {
+			return "", fmt.Errorf("failed to unmarshal input: %w", err)
+		}
+	}
+
+	l.lock.Lock()
+	session, ok := l.sessions[id]
+	l.lock.Unlock()
+	if !ok {
+		return "", fmt.Errorf("session not found for MCP server %s", id)
+	}
+
+	request := mcp.CallToolRequest{}
+	request.Params.Name = toolName
+	request.Params.Arguments = arguments
+
+	result, err := session.Client.CallTool(ctx, request)
+	if err != nil {
+		return "", fmt.Errorf("failed to call tool %s: %w", toolName, err)
+	}
+
+	str, err := json.Marshal(result)
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal result: %w", err)
+	}
+
+	return string(str), nil
+}
diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go
index f5de8e10..3c4264a4 100644
--- a/pkg/tests/runner2_test.go
+++ b/pkg/tests/runner2_test.go
@@ -8,6 +8,7 @@ import (
 	"github.com/gptscript-ai/gptscript/pkg/loader"
 	"github.com/gptscript-ai/gptscript/pkg/runner"
 	"github.com/gptscript-ai/gptscript/pkg/tests/tester"
+	"github.com/gptscript-ai/gptscript/pkg/types"
 	"github.com/hexops/autogold/v2"
 	"github.com/stretchr/testify/require"
 )
@@ -203,3 +204,354 @@ echo "${GPTSCRIPT_INPUT}"
 	require.NoError(t, err)
 	autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data)
 }
+
+func TestMCPLoad(t *testing.T) {
+	r := tester.NewRunner(t)
+	prg, err := loader.ProgramFromSource(context.Background(), `
+name: mcp
+
+#!mcp
+
+{
+	"mcpServers": {
+		"sqlite": {
+			"command": "docker",
+			"args": [
+				"run",
+				"--rm",
+				"-i",
+				"-v",
+				"mcp-test:/mcp",
+				"mcp/sqlite@sha256:007ccae941a6f6db15b26ee41d92edda50ce157176d9273449e8b3f51d979c70",
+				"--db-path",
+				"/mcp/test.db"
+			]
+		}
+	}
+}
+`, "")
+	require.NoError(t, err)
+
+	autogold.Expect(types.Tool{
+		ToolDef: types.ToolDef{
+			Parameters: types.Parameters{
+				Name:        "mcp",
+ Description: "sqlite", + ModelName: "gpt-4o", + Export: []string{ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight", + }, + }, + MetaData: map[string]string{"bundle": "true"}, + }, + ID: "inline:mcp", + ToolMapping: map[string][]types.ToolReference{ + "append_insight": {{ + Reference: "append_insight", + ToolID: "inline:append_insight", + }}, + "create_table": {{ + Reference: "create_table", + ToolID: "inline:create_table", + }}, + "describe_table": {{ + Reference: "describe_table", + ToolID: "inline:describe_table", + }}, + "list_tables": {{ + Reference: "list_tables", + ToolID: "inline:list_tables", + }}, + "read_query": {{ + Reference: "read_query", + ToolID: "inline:read_query", + }}, + "write_query": {{ + Reference: "write_query", + ToolID: "inline:write_query", + }}, + }, + LocalTools: map[string]string{ + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query", + }, + Source: types.ToolSource{Location: "inline"}, + WorkingDir: ".", + }).Equal(t, prg.ToolSet[prg.EntryToolID]) + autogold.Expect(7).Equal(t, len(prg.ToolSet[prg.EntryToolID].LocalTools)) + data, _ := json.MarshalIndent(prg.ToolSet, "", " ") + autogold.Expect(`{ + "inline:append_insight": { + "name": "append_insight", + "description": "Add a business insight to the memo", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "insight": { + "description": "Business insight discovered from data analysis", + "type": "string" + } + }, + "required": [ + "insight" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f append_insight", + "id": "inline:append_insight", + "localTools": { + "append_insight": "inline:append_insight", + 
"create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:create_table": { + "name": "create_table", + "description": "Create a new table in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "CREATE TABLE SQL statement", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f create_table", + "id": "inline:create_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." 
+ }, + "inline:describe_table": { + "name": "describe_table", + "description": "Get the schema information for a specific table", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "table_name": { + "description": "Name of the table to describe", + "type": "string" + } + }, + "required": [ + "table_name" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f describe_table", + "id": "inline:describe_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:list_tables": { + "name": "list_tables", + "description": "List all tables in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f list_tables", + "id": "inline:list_tables", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." 
+ }, + "inline:mcp": { + "name": "mcp", + "description": "sqlite", + "modelName": "gpt-4o", + "internalPrompt": null, + "export": [ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight" + ], + "metaData": { + "bundle": "true" + }, + "id": "inline:mcp", + "toolMapping": { + "append_insight": [ + { + "reference": "append_insight", + "toolID": "inline:append_insight" + } + ], + "create_table": [ + { + "reference": "create_table", + "toolID": "inline:create_table" + } + ], + "describe_table": [ + { + "reference": "describe_table", + "toolID": "inline:describe_table" + } + ], + "list_tables": [ + { + "reference": "list_tables", + "toolID": "inline:list_tables" + } + ], + "read_query": [ + { + "reference": "read_query", + "toolID": "inline:read_query" + } + ], + "write_query": [ + { + "reference": "write_query", + "toolID": "inline:write_query" + } + ] + }, + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." 
+ }, + "inline:read_query": { + "name": "read_query", + "description": "Execute a SELECT query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SELECT SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f read_query", + "id": "inline:read_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:write_query": { + "name": "write_query", + "description": "Execute an INSERT, UPDATE, or DELETE query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke 441826308787ad271e84a381e90d8eccc3fce0fe94503636e679bd0984c79f2f write_query", + "id": "inline:write_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." 
+ } +}`).Equal(t, string(data)) + + prg.EntryToolID = prg.ToolSet[prg.EntryToolID].LocalTools["read_query"] + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"query": "SELECT 1"}`, runner.RunOptions{}) + r.AssertStep(t, resp, err) +} diff --git a/pkg/tests/testdata/TestMCPLoad/call1-resp.golden b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden new file mode 100644 index 00000000..2861a036 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestMCPLoad/call1.golden b/pkg/tests/testdata/TestMCPLoad/call1.golden new file mode 100644 index 00000000..31048a88 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1.golden @@ -0,0 +1,3 @@ +`{ + "model": "gpt-4o" +}` diff --git a/pkg/tests/testdata/TestMCPLoad/step1.golden b/pkg/tests/testdata/TestMCPLoad/step1.golden new file mode 100644 index 00000000..ae20c8ed --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/step1.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "{\"content\":[{\"type\":\"text\",\"text\":\"[{'1': 1}]\"}]}", + "toolID": "", + "state": null +}` diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 54780278..2b8498f4 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -16,11 +16,14 @@ import ( ) const ( - DaemonPrefix = "#!sys.daemon" - OpenAPIPrefix = "#!sys.openapi" - EchoPrefix = "#!sys.echo" - CallPrefix = "#!sys.call" - CommandPrefix = "#!" + DaemonPrefix = "#!sys.daemon" + OpenAPIPrefix = "#!sys.openapi" + EchoPrefix = "#!sys.echo" + CallPrefix = "#!sys.call" + MCPPrefix = "#!mcp" + MCPInvokePrefix = "#!sys.mcp.invoke" + CommandPrefix = "#!" + PromptPrefix = "!!" 
) var ( @@ -862,6 +865,14 @@ func (t Tool) IsDaemon() bool { return strings.HasPrefix(t.Instructions, DaemonPrefix) } +func (t Tool) IsMCP() bool { + return strings.HasPrefix(t.Instructions, MCPPrefix) +} + +func (t Tool) IsMCPInvoke() bool { + return strings.HasPrefix(t.Instructions, MCPInvokePrefix) +} + func (t Tool) IsOpenAPI() bool { return strings.HasPrefix(t.Instructions, OpenAPIPrefix) } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index b5e0d1d5..fe9d7dde 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -44,6 +44,10 @@ func ToDisplayText(tool Tool, input string) string { } func ToSysDisplayString(id string, args map[string]string) (string, error) { + if suffix, ok := strings.CutPrefix(id, "sys.mcp.invoke."); ok { + return fmt.Sprintf("Invoking MCP `%s`", suffix), nil + } + switch id { case "sys.append": return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil