Skip to content

Commit d1d6179

Browse files
committed
chore: sdkserver: update dataset methods for the rewrite
Signed-off-by: Grant Linville <[email protected]>
1 parent d21c001 commit d1d6179

File tree

2 files changed

+23
-165
lines changed

2 files changed

+23
-165
lines changed

pkg/sdkserver/datasets.go

+23-163
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,21 @@ import (
1111
)
1212

1313
func (s *server) getDatasetTool(req datasetRequest) string {
14-
if req.DatasetToolRepo != "" {
15-
return req.DatasetToolRepo
14+
if req.DatasetTool != "" {
15+
return req.DatasetTool
1616
}
1717

1818
return s.datasetTool
1919
}
2020

2121
type datasetRequest struct {
22-
Input string `json:"input"`
23-
WorkspaceID string `json:"workspaceID"`
24-
DatasetToolRepo string `json:"datasetToolRepo"`
25-
Env []string `json:"env"`
22+
Input string `json:"input"`
23+
DatasetTool string `json:"datasetTool"`
24+
Env []string `json:"env"`
2625
}
2726

2827
func (r datasetRequest) validate(requireInput bool) error {
29-
if r.WorkspaceID == "" {
30-
return fmt.Errorf("workspaceID is required")
31-
} else if requireInput && r.Input == "" {
28+
if requireInput && r.Input == "" {
3229
return fmt.Errorf("input is required")
3330
} else if len(r.Env) == 0 {
3431
return fmt.Errorf("env is required")
@@ -38,10 +35,9 @@ func (r datasetRequest) validate(requireInput bool) error {
3835

3936
func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
4037
opts := gptscript.Options{
41-
Cache: o.Cache,
42-
Monitor: o.Monitor,
43-
Runner: o.Runner,
44-
Workspace: r.WorkspaceID,
38+
Cache: o.Cache,
39+
Monitor: o.Monitor,
40+
Runner: o.Runner,
4541
}
4642
return opts
4743
}
@@ -84,148 +80,19 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
8480
writeResponse(logger, w, map[string]any{"stdout": result})
8581
}
8682

87-
type createDatasetArgs struct {
88-
Name string `json:"datasetName"`
89-
Description string `json:"datasetDescription"`
90-
}
91-
92-
func (a createDatasetArgs) validate() error {
93-
if a.Name == "" {
94-
return fmt.Errorf("datasetName is required")
95-
}
96-
return nil
97-
}
98-
99-
func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
100-
logger := gcontext.GetLogger(r.Context())
101-
102-
var req datasetRequest
103-
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
104-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
105-
return
106-
}
107-
108-
if err := req.validate(true); err != nil {
109-
writeError(logger, w, http.StatusBadRequest, err)
110-
return
111-
}
112-
113-
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
114-
if err != nil {
115-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
116-
return
117-
}
118-
119-
var args createDatasetArgs
120-
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
121-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
122-
return
123-
}
124-
125-
if err := args.validate(); err != nil {
126-
writeError(logger, w, http.StatusBadRequest, err)
127-
return
128-
}
129-
130-
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{
131-
Cache: g.Cache,
132-
})
133-
134-
if err != nil {
135-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
136-
return
137-
}
138-
139-
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
140-
if err != nil {
141-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
142-
return
143-
}
144-
145-
writeResponse(logger, w, map[string]any{"stdout": result})
146-
}
147-
148-
type addDatasetElementArgs struct {
149-
DatasetID string `json:"datasetID"`
150-
ElementName string `json:"elementName"`
151-
ElementDescription string `json:"elementDescription"`
152-
ElementContent string `json:"elementContent"`
153-
}
154-
155-
func (a addDatasetElementArgs) validate() error {
156-
if a.DatasetID == "" {
157-
return fmt.Errorf("datasetID is required")
158-
}
159-
if a.ElementName == "" {
160-
return fmt.Errorf("elementName is required")
161-
}
162-
if a.ElementContent == "" {
163-
return fmt.Errorf("elementContent is required")
164-
}
165-
return nil
166-
}
167-
168-
func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
169-
logger := gcontext.GetLogger(r.Context())
170-
171-
var req datasetRequest
172-
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
173-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
174-
return
175-
}
176-
177-
if err := req.validate(true); err != nil {
178-
writeError(logger, w, http.StatusBadRequest, err)
179-
return
180-
}
181-
182-
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
183-
if err != nil {
184-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
185-
return
186-
}
187-
188-
var args addDatasetElementArgs
189-
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
190-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
191-
return
192-
}
193-
194-
if err := args.validate(); err != nil {
195-
writeError(logger, w, http.StatusBadRequest, err)
196-
return
197-
}
198-
199-
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{
200-
Cache: g.Cache,
201-
})
202-
if err != nil {
203-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
204-
return
205-
}
206-
207-
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
208-
if err != nil {
209-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
210-
return
211-
}
212-
213-
writeResponse(logger, w, map[string]any{"stdout": result})
214-
}
215-
21683
type addDatasetElementsArgs struct {
217-
DatasetID string `json:"datasetID"`
218-
Elements []struct {
219-
Name string `json:"name"`
220-
Description string `json:"description"`
221-
Contents string `json:"contents"`
222-
}
84+
DatasetID string `json:"datasetID"`
85+
Name string `json:"name"`
86+
Description string `json:"description"`
87+
Elements []struct {
88+
Name string `json:"name"`
89+
Description string `json:"description"`
90+
Contents string `json:"contents"`
91+
BinaryContents []byte `json:"binaryContents"`
92+
} `json:"elements"`
22393
}
22494

22595
func (a addDatasetElementsArgs) validate() error {
226-
if a.DatasetID == "" {
227-
return fmt.Errorf("datasetID is required")
228-
}
22996
if len(a.Elements) == 0 {
23097
return fmt.Errorf("elements is required")
23198
}
@@ -271,13 +138,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
271138
return
272139
}
273140

274-
elementsJSON, err := json.Marshal(args.Elements)
275-
if err != nil {
276-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
277-
return
278-
}
279-
280-
result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON)))
141+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
281142
if err != nil {
282143
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
283144
return
@@ -347,15 +208,14 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
347208

348209
type getDatasetElementArgs struct {
349210
DatasetID string `json:"datasetID"`
350-
Element string `json:"element"`
211+
Name string `json:"name"`
351212
}
352213

353214
func (a getDatasetElementArgs) validate() error {
354215
if a.DatasetID == "" {
355216
return fmt.Errorf("datasetID is required")
356-
}
357-
if a.Element == "" {
358-
return fmt.Errorf("element is required")
217+
} else if a.Name == "" {
218+
return fmt.Errorf("name is required")
359219
}
360220
return nil
361221
}
@@ -391,7 +251,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
391251
return
392252
}
393253

394-
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{
254+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element", loader.Options{
395255
Cache: g.Cache,
396256
})
397257
if err != nil {

pkg/sdkserver/routes.go

-2
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,8 @@ func (s *server) addRoutes(mux *http.ServeMux) {
6969
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
7070

7171
mux.HandleFunc("POST /datasets", s.listDatasets)
72-
mux.HandleFunc("POST /datasets/create", s.createDataset)
7372
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
7473
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
75-
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
7674
mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)
7775

7876
mux.HandleFunc("POST /workspaces/create", s.createWorkspace)

0 commit comments

Comments
 (0)