Skip to content

Commit ba1aba4

Browse files
committed
feat: add prompt support
Signed-off-by: Donnie Adams <[email protected]>
1 parent b9cbffb commit ba1aba4

File tree

4 files changed

+137
-32
lines changed

4 files changed

+137
-32
lines changed

client.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ const relativeToBinaryPath = "<me>"
2727
type Client interface {
2828
Run(context.Context, string, Options) (*Run, error)
2929
Evaluate(context.Context, Options, ...fmt.Stringer) (*Run, error)
30-
Parse(ctx context.Context, fileName string) ([]Node, error)
31-
ParseTool(ctx context.Context, toolDef string) ([]Node, error)
32-
Version(ctx context.Context) (string, error)
33-
Fmt(ctx context.Context, nodes []Node) (string, error)
34-
ListTools(ctx context.Context) (string, error)
35-
ListModels(ctx context.Context) ([]string, error)
36-
Confirm(ctx context.Context, resp AuthResponse) error
30+
Parse(context.Context, string) ([]Node, error)
31+
ParseTool(context.Context, string) ([]Node, error)
32+
Version(context.Context) (string, error)
33+
Fmt(context.Context, []Node) (string, error)
34+
ListTools(context.Context) (string, error)
35+
ListModels(context.Context) ([]string, error)
36+
Confirm(context.Context, AuthResponse) error
37+
PromptResponse(context.Context, PromptResponse) error
3738
Close()
3839
}
3940

@@ -208,6 +209,11 @@ func (c *client) Confirm(ctx context.Context, resp AuthResponse) error {
208209
return err
209210
}
210211

212+
func (c *client) PromptResponse(ctx context.Context, resp PromptResponse) error {
213+
_, err := c.runBasicCommand(ctx, "prompt-response/"+resp.RunID, resp.Response)
214+
return err
215+
}
216+
211217
func (c *client) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
212218
run := &Run{
213219
url: c.gptscriptURL,

client_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,91 @@ func TestConfirmDeny(t *testing.T) {
808808
}
809809
}
810810

811+
func TestPrompt(t *testing.T) {
812+
var eventContent string
813+
tools := []fmt.Stringer{
814+
&ToolDef{
815+
Instructions: "Use the sys.prompt user to ask the user for 'first name' which is not sensitive. After you get their first name, say hello.",
816+
Tools: []string{"sys.prompt"},
817+
},
818+
}
819+
820+
run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tools...)
821+
if err != nil {
822+
t.Errorf("Error executing tool: %v", err)
823+
}
824+
825+
// Wait for the prompt event
826+
var promptFrame *PromptFrame
827+
for e := range run.Events() {
828+
if e.Call != nil {
829+
for _, o := range e.Call.Output {
830+
eventContent += o.Content
831+
}
832+
833+
}
834+
if e.Prompt != nil {
835+
if e.Prompt.Type == EventTypePrompt {
836+
promptFrame = e.Prompt
837+
break
838+
}
839+
}
840+
}
841+
842+
if promptFrame == nil {
843+
t.Fatalf("No prompt call event")
844+
}
845+
846+
if promptFrame.Sensitive {
847+
t.Errorf("Unexpected sensitive prompt event: %v", promptFrame.Sensitive)
848+
}
849+
850+
if !strings.Contains(promptFrame.Message, "first name") {
851+
t.Errorf("unexpected confirm input: %s", promptFrame.Message)
852+
}
853+
854+
if len(promptFrame.Fields) != 1 {
855+
t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields))
856+
}
857+
858+
if promptFrame.Fields[0] != "first name" {
859+
t.Errorf("Unexpected field: %s", promptFrame.Fields[0])
860+
}
861+
862+
if err = c.PromptResponse(context.Background(), PromptResponse{
863+
RunID: promptFrame.RunID,
864+
Response: map[string]string{promptFrame.Fields[0]: "Clicky"},
865+
}); err != nil {
866+
t.Errorf("Error responding: %v", err)
867+
}
868+
869+
// Read the remainder of the events
870+
for e := range run.Events() {
871+
if e.Call != nil {
872+
for _, o := range e.Call.Output {
873+
eventContent += o.Content
874+
}
875+
}
876+
}
877+
878+
out, err := run.Text()
879+
if err != nil {
880+
t.Errorf("Error reading output: %v", err)
881+
}
882+
883+
if !strings.Contains(eventContent, "Clicky") {
884+
t.Errorf("Unexpected event output: %s", eventContent)
885+
}
886+
887+
if !strings.Contains(out, "Hello") || !strings.Contains(out, "Clicky") {
888+
t.Errorf("Unexpected output: %s", out)
889+
}
890+
891+
if len(run.ErrorOutput()) != 0 {
892+
t.Errorf("Should have no stderr output: %v", run.ErrorOutput())
893+
}
894+
}
895+
811896
func TestGetCommand(t *testing.T) {
812897
currentEnvVar := os.Getenv("GPTSCRIPT_BIN")
813898
t.Cleanup(func() {

frame.go

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
package gptscript
22

3-
import (
4-
"time"
3+
import "time"
4+
5+
type ToolCategory string
6+
7+
type EventType string
8+
9+
const (
10+
CredentialToolCategory ToolCategory = "credential"
11+
ContextToolCategory ToolCategory = "context"
12+
NoCategory ToolCategory = ""
13+
14+
EventTypeRunStart EventType = "runStart"
15+
EventTypeCallStart EventType = "callStart"
16+
EventTypeCallContinue EventType = "callContinue"
17+
EventTypeCallSubCalls EventType = "callSubCalls"
18+
EventTypeCallProgress EventType = "callProgress"
19+
EventTypeChat EventType = "callChat"
20+
EventTypeCallConfirm EventType = "callConfirm"
21+
EventTypeCallFinish EventType = "callFinish"
22+
EventTypeRunFinish EventType = "runFinish"
23+
24+
EventTypePrompt EventType = "prompt"
525
)
626

727
type Frame struct {
8-
Run *RunFrame `json:"run,omitempty"`
9-
Call *CallFrame `json:"call,omitempty"`
28+
Run *RunFrame `json:"run,omitempty"`
29+
Call *CallFrame `json:"call,omitempty"`
30+
Prompt *PromptFrame `json:"prompt,omitempty"`
1031
}
1132

1233
type RunFrame struct {
@@ -74,24 +95,11 @@ type InputContext struct {
7495
Content string `json:"content,omitempty"`
7596
}
7697

77-
type ToolCategory string
78-
79-
const (
80-
CredentialToolCategory ToolCategory = "credential"
81-
ContextToolCategory ToolCategory = "context"
82-
NoCategory ToolCategory = ""
83-
)
84-
85-
type EventType string
86-
87-
const (
88-
EventTypeRunStart EventType = "runStart"
89-
EventTypeCallStart EventType = "callStart"
90-
EventTypeCallContinue EventType = "callContinue"
91-
EventTypeCallSubCalls EventType = "callSubCalls"
92-
EventTypeCallProgress EventType = "callProgress"
93-
EventTypeChat EventType = "callChat"
94-
EventTypeCallConfirm EventType = "callConfirm"
95-
EventTypeCallFinish EventType = "callFinish"
96-
EventTypeRunFinish EventType = "runFinish"
97-
)
98+
type PromptFrame struct {
99+
RunID string `json:"runID,omitempty"`
100+
Type EventType `json:"type,omitempty"`
101+
Time time.Time `json:"time,omitempty"`
102+
Message string `json:"message,omitempty"`
103+
Fields []string `json:"fields,omitempty"`
104+
Sensitive bool `json:"sensitive,omitempty"`
105+
}

prompt.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package gptscript
2+
3+
type PromptResponse struct {
4+
RunID string `json:"runID,omitempty"`
5+
Response map[string]string `json:"response,omitempty"`
6+
}

0 commit comments

Comments
 (0)