Skip to content

Commit 53bf7cd

Browse files
authored
enhance: send MCP errors back to the LLM so it can correct if possible (#974)
Signed-off-by: Donnie Adams <[email protected]>
1 parent 20f384d commit 53bf7cd

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

pkg/engine/cmd.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func compressEnv(envs []string) (result []string) {
6565
return
6666
}
6767

68-
func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
68+
func (e *Engine) runCommand(ctx Context, tool types.Tool, input string) (cmdOut string, cmdErr error) {
6969
id := counter.Next()
7070

7171
var combinedOutput string
@@ -128,7 +128,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
128128

129129
cmd, stop, err := e.newCommand(commandCtx, extraEnv, tool, input, true)
130130
if err != nil {
131-
if toolCategory == NoCategory && ctx.Parent != nil {
131+
if ctx.ToolCategory == NoCategory && ctx.Parent != nil {
132132
return fmt.Sprintf("ERROR: got (%v) while parsing command", err), nil
133133
}
134134
return "", fmt.Errorf("got (%v) while parsing command", err)
@@ -167,7 +167,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
167167

168168
if err := cmd.Run(); err != nil && (commandCtx.Err() == nil || ctx.Ctx.Err() != nil) {
169169
// If the command failed and the context hasn't been canceled, then return the error.
170-
if toolCategory == NoCategory && ctx.Parent != nil {
170+
if ctx.ToolCategory == NoCategory && ctx.Parent != nil {
171171
// If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry.
172172
return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, stdoutAndErr), nil
173173
}

pkg/engine/engine.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type Engine struct {
4545
}
4646

4747
type MCPRunner interface {
48-
Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error)
48+
Run(ctx Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error)
4949
}
5050

5151
type State struct {
@@ -313,7 +313,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too
313313
}
314314

315315
func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) {
316-
output, err := e.MCPRunner.Run(ctx.Ctx, e.Progress, tool, input)
316+
output, err := e.MCPRunner.Run(ctx, e.Progress, tool, input)
317317
if err != nil {
318318
return nil, fmt.Errorf("failed to run MCP invoke: %w", err)
319319
}
@@ -335,7 +335,7 @@ func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*R
335335
} else if tool.IsCall() {
336336
return e.runCall(ctx, tool, input)
337337
}
338-
s, err := e.runCommand(ctx, tool, input, ctx.ToolCategory)
338+
s, err := e.runCommand(ctx, tool, input)
339339
return &Return{
340340
Result: &s,
341341
}, err

pkg/mcp/runner.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
package mcp
22

33
import (
4-
"context"
54
"encoding/json"
65
"fmt"
76
"strings"
87

8+
"github.com/gptscript-ai/gptscript/pkg/engine"
99
"github.com/gptscript-ai/gptscript/pkg/types"
1010
"github.com/mark3labs/mcp-go/mcp"
1111
)
1212

13-
func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) {
13+
func (l *Local) Run(ctx engine.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) {
1414
fields := strings.Fields(tool.Instructions)
1515
if len(fields) < 2 {
1616
return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions)
@@ -41,8 +41,16 @@ func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool t
4141
request.Params.Name = toolName
4242
request.Params.Arguments = arguments
4343

44-
result, err := session.Client.CallTool(ctx, request)
44+
result, err := session.Client.CallTool(ctx.Ctx, request)
4545
if err != nil {
46+
if ctx.ToolCategory == engine.NoCategory && ctx.Parent != nil {
47+
var output []byte
48+
if result != nil {
49+
output, _ = json.Marshal(result)
50+
}
51+
// If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry.
52+
return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, string(output)), nil
53+
}
4654
return "", fmt.Errorf("failed to call tool %s: %w", toolName, err)
4755
}
4856

0 commit comments

Comments
 (0)