Skip to content

Commit 468c4f1

Browse files
committed
enhance: add field-level sensitivity for prompts
Additionally, each field can now also have a description. This change is made such that all existing tools will work. However, existing code will need to be updated to support the new types. Signed-off-by: Donnie Adams <[email protected]>
1 parent 7ee5c80 commit 468c4f1

File tree

7 files changed

+227
-26
lines changed

7 files changed

+227
-26
lines changed

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ require (
1717
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
1818
github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1
1919
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
20-
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e
21-
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6
20+
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61
21+
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee
2222
github.com/hexops/autogold/v2 v2.2.1
2323
github.com/hexops/valast v1.4.4
2424
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056

go.sum

+4-4
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f
201201
github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
202202
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
203203
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
204-
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA=
205-
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
206-
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8=
207-
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4=
204+
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI=
205+
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
206+
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE=
207+
github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY=
208208
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
209209
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
210210
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=

pkg/cli/gptscript.go

-1
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
494494
DisableCache: r.DisableCache,
495495
CredentialOverrides: r.CredentialOverride,
496496
Input: toolInput,
497-
CacheDir: r.CacheDir,
498497
SubTool: r.SubTool,
499498
Workspace: r.Workspace,
500499
SaveChatStateFile: r.SaveChatStateFile,

pkg/engine/call.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ func mergeInputs(base, overlay string) (string, error) {
7676
return base, nil
7777
}
7878

79-
err := json.Unmarshal([]byte(base), &baseMap)
80-
if err != nil {
81-
return "", fmt.Errorf("failed to unmarshal base input: %w", err)
79+
if base != "" {
80+
if err := json.Unmarshal([]byte(base), &baseMap); err != nil {
81+
return "", fmt.Errorf("failed to unmarshal base input: %w", err)
82+
}
8283
}
8384

84-
err = json.Unmarshal([]byte(overlay), &overlayMap)
85-
if err != nil {
85+
if err := json.Unmarshal([]byte(overlay), &overlayMap); err != nil {
8686
return "", fmt.Errorf("failed to unmarshal overlay input: %w", err)
8787
}
8888

pkg/prompt/prompt.go

+12-13
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,19 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
5252
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
5353
var params struct {
5454
Message string `json:"message,omitempty"`
55-
Fields string `json:"fields,omitempty"`
55+
Fields types.Fields `json:"fields,omitempty"`
5656
Sensitive string `json:"sensitive,omitempty"`
5757
Metadata map[string]string `json:"metadata,omitempty"`
5858
}
5959
if err := json.Unmarshal([]byte(input), &params); err != nil {
6060
return "", err
6161
}
6262

63-
var fields []string
6463
for _, env := range envs {
6564
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {
66-
if params.Fields != "" {
67-
fields = strings.Split(params.Fields, ",")
68-
}
69-
7065
httpPrompt := types.Prompt{
7166
Message: params.Message,
72-
Fields: fields,
67+
Fields: params.Fields,
7368
Sensitive: params.Sensitive == "true",
7469
Metadata: params.Metadata,
7570
}
@@ -102,21 +97,25 @@ func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
10297
results := map[string]string{}
10398
for _, f := range req.Fields {
10499
var (
105-
value string
106-
msg = f
100+
value string
101+
msg = f.Name
102+
sensitive = req.Sensitive
107103
)
104+
if f.Sensitive != nil {
105+
sensitive = *f.Sensitive
106+
}
108107
if len(req.Fields) == 1 && req.Message != "" {
109108
msg = req.Message
110109
}
111-
if req.Sensitive {
112-
err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
110+
if sensitive {
111+
err = survey.AskOne(&survey.Password{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
113112
} else {
114-
err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
113+
err = survey.AskOne(&survey.Input{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
115114
}
116115
if err != nil {
117116
return "", err
118117
}
119-
results[f] = value
118+
results[f.Name] = value
120119
}
121120

122121
resultsStr, err := json.Marshal(results)

pkg/types/prompt.go

+62-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,74 @@
11
package types
22

3+
import (
4+
"encoding/json"
5+
"strings"
6+
)
7+
38
const (
49
PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL"
510
PromptTokenEnvVar = "GPTSCRIPT_PROMPT_TOKEN"
611
)
712

813
type Prompt struct {
914
Message string `json:"message,omitempty"`
10-
Fields []string `json:"fields,omitempty"`
15+
Fields Fields `json:"fields,omitempty"`
1116
Sensitive bool `json:"sensitive,omitempty"`
1217
Metadata map[string]string `json:"metadata,omitempty"`
1318
}
19+
20+
type Field struct {
21+
Name string `json:"name,omitempty"`
22+
Sensitive *bool `json:"sensitive,omitempty"`
23+
Description string `json:"description,omitempty"`
24+
}
25+
26+
type Fields []Field
27+
28+
// UnmarshalJSON will unmarshal the corresponding JSON object for Fields,
29+
// or a comma-separated strings (for backwards compatibility).
30+
func (f *Fields) UnmarshalJSON(b []byte) error {
31+
if len(b) == 0 || f == nil {
32+
return nil
33+
}
34+
35+
if b[0] == '[' {
36+
var arr []Field
37+
if err := json.Unmarshal(b, &arr); err != nil {
38+
return err
39+
}
40+
*f = arr
41+
return nil
42+
}
43+
44+
var fields string
45+
if err := json.Unmarshal(b, &fields); err != nil {
46+
return err
47+
}
48+
49+
if fields != "" {
50+
fieldsArr := strings.Split(fields, ",")
51+
*f = make([]Field, 0, len(fieldsArr))
52+
for _, field := range fieldsArr {
53+
*f = append(*f, Field{Name: strings.TrimSpace(field)})
54+
}
55+
}
56+
57+
return nil
58+
}
59+
60+
type field *Field
61+
62+
// UnmarshalJSON will unmarshal the corresponding JSON object for a Field,
63+
// or a string (for backwards compatibility).
64+
func (f *Field) UnmarshalJSON(b []byte) error {
65+
if len(b) == 0 || f == nil {
66+
return nil
67+
}
68+
69+
if b[0] == '{' {
70+
return json.Unmarshal(b, field(f))
71+
}
72+
73+
return json.Unmarshal(b, &f.Name)
74+
}

pkg/types/prompt_test.go

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package types
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
func TestFieldUnmarshalJSON(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
input []byte
12+
expected Field
13+
expectErr bool
14+
}{
15+
{
16+
name: "valid single Field object JSON",
17+
input: []byte(`{"name":"field1","sensitive":true,"description":"A test field"}`),
18+
expected: Field{Name: "field1", Sensitive: boolPtr(true), Description: "A test field"},
19+
expectErr: false,
20+
},
21+
{
22+
name: "valid Field name as string",
23+
input: []byte(`"field1"`),
24+
expected: Field{Name: "field1"},
25+
expectErr: false,
26+
},
27+
{
28+
name: "empty input",
29+
input: []byte(``),
30+
expected: Field{},
31+
expectErr: false,
32+
},
33+
{
34+
name: "invalid JSON object",
35+
input: []byte(`{"name":"field1","sensitive":"not_boolean"}`),
36+
expected: Field{Name: "field1", Sensitive: new(bool)},
37+
expectErr: true,
38+
},
39+
{
40+
name: "extra unknown fields in JSON object",
41+
input: []byte(`{"name":"field1","unknown":"field","sensitive":false}`),
42+
expected: Field{Name: "field1", Sensitive: boolPtr(false)},
43+
expectErr: false,
44+
},
45+
{
46+
name: "malformed JSON",
47+
input: []byte(`{"name":"field1","sensitive":true`),
48+
expected: Field{},
49+
expectErr: true,
50+
},
51+
}
52+
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
var field Field
56+
err := field.UnmarshalJSON(tt.input)
57+
if (err != nil) != tt.expectErr {
58+
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr)
59+
}
60+
if !reflect.DeepEqual(field, tt.expected) {
61+
t.Errorf("UnmarshalJSON() = %v, expected %v", field, tt.expected)
62+
}
63+
})
64+
}
65+
}
66+
67+
func TestFieldsUnmarshalJSON(t *testing.T) {
68+
tests := []struct {
69+
name string
70+
input []byte
71+
expected Fields
72+
expectErr bool
73+
}{
74+
{
75+
name: "empty input",
76+
input: nil,
77+
expected: nil,
78+
expectErr: false,
79+
},
80+
{
81+
name: "nil pointer",
82+
input: nil,
83+
expected: nil,
84+
expectErr: false,
85+
},
86+
{
87+
name: "valid JSON array",
88+
input: []byte(`[{"Name":"field1"},{"Name":"field2"}]`),
89+
expected: Fields{{Name: "field1"}, {Name: "field2"}},
90+
expectErr: false,
91+
},
92+
{
93+
name: "single string input",
94+
input: []byte(`"field1,field2,field3"`),
95+
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}},
96+
expectErr: false,
97+
},
98+
{
99+
name: "trim spaces in single string input",
100+
input: []byte(`"field1, field2 , field3 "`),
101+
expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}},
102+
expectErr: false,
103+
},
104+
{
105+
name: "invalid JSON array",
106+
input: []byte(`[{"Name":"field1"},{"Name":field2}]`),
107+
expected: nil,
108+
expectErr: true,
109+
},
110+
{
111+
name: "invalid single string",
112+
input: []byte(`1234`),
113+
expected: nil,
114+
expectErr: true,
115+
},
116+
{
117+
name: "empty array",
118+
input: []byte(`[]`),
119+
expected: Fields{},
120+
expectErr: false,
121+
},
122+
{
123+
name: "empty string",
124+
input: []byte(`""`),
125+
expected: nil,
126+
expectErr: false,
127+
},
128+
}
129+
130+
for _, tt := range tests {
131+
t.Run(tt.name, func(t *testing.T) {
132+
var fields Fields
133+
err := fields.UnmarshalJSON(tt.input)
134+
if (err != nil) != tt.expectErr {
135+
t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr)
136+
}
137+
if !reflect.DeepEqual(fields, tt.expected) {
138+
t.Errorf("UnmarshalJSON() = %v, expected %v", fields, tt.expected)
139+
}
140+
})
141+
}
142+
}

0 commit comments

Comments
 (0)