Skip to content

Commit 0bb7bbf

Browse files
committed
workspaceID is no longer part of tool input
Signed-off-by: Grant Linville <[email protected]>
1 parent 5dab330 commit 0bb7bbf

File tree

1 file changed

+21
-48
lines changed

1 file changed

+21
-48
lines changed

pkg/sdkserver/datasets.go

+21-48
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ import (
1212

1313
type datasetRequest struct {
1414
Input string `json:"input"`
15+
WorkspaceID string `json:"workspaceID"`
1516
DatasetTool string `json:"datasetTool"`
1617
Env []string `json:"env"`
1718
}
1819

19-
func (r datasetRequest) validate() error {
20-
if r.Input == "" {
20+
func (r datasetRequest) validate(requireInput bool) error {
21+
if requireInput && r.Input == "" {
2122
return fmt.Errorf("input is required")
23+
} else if r.WorkspaceID == "" {
24+
return fmt.Errorf("workspaceID is required")
2225
} else if len(r.Env) == 0 {
2326
return fmt.Errorf("env is required")
2427
}
@@ -27,9 +30,10 @@ func (r datasetRequest) validate() error {
2730

2831
func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
2932
opts := gptscript.Options{
30-
Cache: o.Cache,
31-
Monitor: o.Monitor,
32-
Runner: o.Runner,
33+
Cache: o.Cache,
34+
Monitor: o.Monitor,
35+
Runner: o.Runner,
36+
Workspace: r.WorkspaceID,
3337
}
3438
return opts
3539
}
@@ -41,17 +45,6 @@ func (r datasetRequest) getToolRepo() string {
4145
return "github.com/otto8-ai/datasets"
4246
}
4347

44-
type listDatasetsArgs struct {
45-
WorkspaceID string `json:"workspaceID"`
46-
}
47-
48-
func (a listDatasetsArgs) validate() error {
49-
if a.WorkspaceID == "" {
50-
return fmt.Errorf("workspaceID is required")
51-
}
52-
return nil
53-
}
54-
5548
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
5649
logger := gcontext.GetLogger(r.Context())
5750

@@ -61,7 +54,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
6154
return
6255
}
6356

64-
if err := req.validate(); err != nil {
57+
if err := req.validate(false); err != nil {
6558
writeError(logger, w, http.StatusBadRequest, err)
6659
return
6760
}
@@ -72,17 +65,6 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
7265
return
7366
}
7467

75-
var args listDatasetsArgs
76-
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
77-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
78-
return
79-
}
80-
81-
if err := args.validate(); err != nil {
82-
writeError(logger, w, http.StatusBadRequest, err)
83-
return
84-
}
85-
8668
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
8769
Cache: g.Cache,
8870
})
@@ -102,9 +84,8 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
10284
}
10385

10486
type addDatasetElementsArgs struct {
105-
WorkspaceID string `json:"workspaceID"`
106-
DatasetID string `json:"datasetID"`
107-
Elements []struct {
87+
DatasetID string `json:"datasetID"`
88+
Elements []struct {
10889
Name string `json:"name"`
10990
Description string `json:"description"`
11091
Contents string `json:"contents"`
@@ -113,9 +94,7 @@ type addDatasetElementsArgs struct {
11394
}
11495

11596
func (a addDatasetElementsArgs) validate() error {
116-
if a.WorkspaceID == "" {
117-
return fmt.Errorf("workspaceID is required")
118-
} else if len(a.Elements) == 0 {
97+
if len(a.Elements) == 0 {
11998
return fmt.Errorf("elements is required")
12099
}
121100
return nil
@@ -130,7 +109,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
130109
return
131110
}
132111

133-
if err := req.validate(); err != nil {
112+
if err := req.validate(true); err != nil {
134113
writeError(logger, w, http.StatusBadRequest, err)
135114
return
136115
}
@@ -170,14 +149,11 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
170149
}
171150

172151
type listDatasetElementsArgs struct {
173-
WorkspaceID string `json:"workspaceID"`
174-
DatasetID string `json:"datasetID"`
152+
DatasetID string `json:"datasetID"`
175153
}
176154

177155
func (a listDatasetElementsArgs) validate() error {
178-
if a.WorkspaceID == "" {
179-
return fmt.Errorf("workspaceID is required")
180-
} else if a.DatasetID == "" {
156+
if a.DatasetID == "" {
181157
return fmt.Errorf("datasetID is required")
182158
}
183159
return nil
@@ -192,7 +168,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
192168
return
193169
}
194170

195-
if err := req.validate(); err != nil {
171+
if err := req.validate(true); err != nil {
196172
writeError(logger, w, http.StatusBadRequest, err)
197173
return
198174
}
@@ -232,15 +208,12 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
232208
}
233209

234210
type getDatasetElementArgs struct {
235-
WorkspaceID string `json:"workspaceID"`
236-
DatasetID string `json:"datasetID"`
237-
Name string `json:"name"`
211+
DatasetID string `json:"datasetID"`
212+
Name string `json:"name"`
238213
}
239214

240215
func (a getDatasetElementArgs) validate() error {
241-
if a.WorkspaceID == "" {
242-
return fmt.Errorf("workspaceID is required")
243-
} else if a.DatasetID == "" {
216+
if a.DatasetID == "" {
244217
return fmt.Errorf("datasetID is required")
245218
} else if a.Name == "" {
246219
return fmt.Errorf("name is required")
@@ -257,7 +230,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
257230
return
258231
}
259232

260-
if err := req.validate(); err != nil {
233+
if err := req.validate(true); err != nil {
261234
writeError(logger, w, http.StatusBadRequest, err)
262235
return
263236
}

0 commit comments

Comments
 (0)