@@ -12,13 +12,16 @@ import (
12
12
13
13
type datasetRequest struct {
14
14
Input string `json:"input"`
15
+ WorkspaceID string `json:"workspaceID"`
15
16
DatasetTool string `json:"datasetTool"`
16
17
Env []string `json:"env"`
17
18
}
18
19
19
- func (r datasetRequest ) validate () error {
20
- if r .Input == "" {
20
+ func (r datasetRequest ) validate (requireInput bool ) error {
21
+ if requireInput && r .Input == "" {
21
22
return fmt .Errorf ("input is required" )
23
+ } else if r .WorkspaceID == "" {
24
+ return fmt .Errorf ("workspaceID is required" )
22
25
} else if len (r .Env ) == 0 {
23
26
return fmt .Errorf ("env is required" )
24
27
}
@@ -27,9 +30,10 @@ func (r datasetRequest) validate() error {
27
30
28
31
func (r datasetRequest ) opts (o gptscript.Options ) gptscript.Options {
29
32
opts := gptscript.Options {
30
- Cache : o .Cache ,
31
- Monitor : o .Monitor ,
32
- Runner : o .Runner ,
33
+ Cache : o .Cache ,
34
+ Monitor : o .Monitor ,
35
+ Runner : o .Runner ,
36
+ Workspace : r .WorkspaceID ,
33
37
}
34
38
return opts
35
39
}
@@ -41,17 +45,6 @@ func (r datasetRequest) getToolRepo() string {
41
45
return "github.com/otto8-ai/datasets"
42
46
}
43
47
44
- type listDatasetsArgs struct {
45
- WorkspaceID string `json:"workspaceID"`
46
- }
47
-
48
- func (a listDatasetsArgs ) validate () error {
49
- if a .WorkspaceID == "" {
50
- return fmt .Errorf ("workspaceID is required" )
51
- }
52
- return nil
53
- }
54
-
55
48
func (s * server ) listDatasets (w http.ResponseWriter , r * http.Request ) {
56
49
logger := gcontext .GetLogger (r .Context ())
57
50
@@ -61,7 +54,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
61
54
return
62
55
}
63
56
64
- if err := req .validate (); err != nil {
57
+ if err := req .validate (false ); err != nil {
65
58
writeError (logger , w , http .StatusBadRequest , err )
66
59
return
67
60
}
@@ -72,17 +65,6 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
72
65
return
73
66
}
74
67
75
- var args listDatasetsArgs
76
- if err := json .Unmarshal ([]byte (req .Input ), & args ); err != nil {
77
- writeError (logger , w , http .StatusBadRequest , fmt .Errorf ("failed to unmarshal input: %w" , err ))
78
- return
79
- }
80
-
81
- if err := args .validate (); err != nil {
82
- writeError (logger , w , http .StatusBadRequest , err )
83
- return
84
- }
85
-
86
68
prg , err := loader .Program (r .Context (), req .getToolRepo (), "List Datasets" , loader.Options {
87
69
Cache : g .Cache ,
88
70
})
@@ -102,9 +84,8 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
102
84
}
103
85
104
86
type addDatasetElementsArgs struct {
105
- WorkspaceID string `json:"workspaceID"`
106
- DatasetID string `json:"datasetID"`
107
- Elements []struct {
87
+ DatasetID string `json:"datasetID"`
88
+ Elements []struct {
108
89
Name string `json:"name"`
109
90
Description string `json:"description"`
110
91
Contents string `json:"contents"`
@@ -113,9 +94,7 @@ type addDatasetElementsArgs struct {
113
94
}
114
95
115
96
func (a addDatasetElementsArgs ) validate () error {
116
- if a .WorkspaceID == "" {
117
- return fmt .Errorf ("workspaceID is required" )
118
- } else if len (a .Elements ) == 0 {
97
+ if len (a .Elements ) == 0 {
119
98
return fmt .Errorf ("elements is required" )
120
99
}
121
100
return nil
@@ -130,7 +109,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
130
109
return
131
110
}
132
111
133
- if err := req .validate (); err != nil {
112
+ if err := req .validate (true ); err != nil {
134
113
writeError (logger , w , http .StatusBadRequest , err )
135
114
return
136
115
}
@@ -170,14 +149,11 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
170
149
}
171
150
172
151
type listDatasetElementsArgs struct {
173
- WorkspaceID string `json:"workspaceID"`
174
- DatasetID string `json:"datasetID"`
152
+ DatasetID string `json:"datasetID"`
175
153
}
176
154
177
155
func (a listDatasetElementsArgs ) validate () error {
178
- if a .WorkspaceID == "" {
179
- return fmt .Errorf ("workspaceID is required" )
180
- } else if a .DatasetID == "" {
156
+ if a .DatasetID == "" {
181
157
return fmt .Errorf ("datasetID is required" )
182
158
}
183
159
return nil
@@ -192,7 +168,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
192
168
return
193
169
}
194
170
195
- if err := req .validate (); err != nil {
171
+ if err := req .validate (true ); err != nil {
196
172
writeError (logger , w , http .StatusBadRequest , err )
197
173
return
198
174
}
@@ -232,15 +208,12 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
232
208
}
233
209
234
210
type getDatasetElementArgs struct {
235
- WorkspaceID string `json:"workspaceID"`
236
- DatasetID string `json:"datasetID"`
237
- Name string `json:"name"`
211
+ DatasetID string `json:"datasetID"`
212
+ Name string `json:"name"`
238
213
}
239
214
240
215
func (a getDatasetElementArgs ) validate () error {
241
- if a .WorkspaceID == "" {
242
- return fmt .Errorf ("workspaceID is required" )
243
- } else if a .DatasetID == "" {
216
+ if a .DatasetID == "" {
244
217
return fmt .Errorf ("datasetID is required" )
245
218
} else if a .Name == "" {
246
219
return fmt .Errorf ("name is required" )
@@ -257,7 +230,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
257
230
return
258
231
}
259
232
260
- if err := req .validate (); err != nil {
233
+ if err := req .validate (true ); err != nil {
261
234
writeError (logger , w , http .StatusBadRequest , err )
262
235
return
263
236
}
0 commit comments