Skip to content

Commit d21c001

Browse files
committed
chore: allow setting of dataset tool in SDK server config
Signed-off-by: Donnie Adams <[email protected]>
1 parent 2a9f664 commit d21c001

File tree

4 files changed

+30
-22
lines changed

4 files changed

+30
-22
lines changed

pkg/cli/sdk_server.go

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
type SDKServer struct {
1313
*GPTScript
14+
DatasetTool string `usage:"Tool to use for datasets"`
1415
WorkspaceTool string `usage:"Tool to use for workspace"`
1516
}
1617

@@ -38,6 +39,7 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error {
3839
Options: opts,
3940
ListenAddress: c.ListenAddress,
4041
Debug: c.Debug,
42+
DatasetTool: c.DatasetTool,
4143
WorkspaceTool: c.WorkspaceTool,
4244
})
4345
}

pkg/sdkserver/datasets.go

+14-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ import (
1010
"github.com/gptscript-ai/gptscript/pkg/loader"
1111
)
1212

13+
func (s *server) getDatasetTool(req datasetRequest) string {
14+
if req.DatasetToolRepo != "" {
15+
return req.DatasetToolRepo
16+
}
17+
18+
return s.datasetTool
19+
}
20+
1321
type datasetRequest struct {
1422
Input string `json:"input"`
1523
WorkspaceID string `json:"workspaceID"`
@@ -38,13 +46,6 @@ func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
3846
return opts
3947
}
4048

41-
func (r datasetRequest) getToolRepo() string {
42-
if r.DatasetToolRepo != "" {
43-
return r.DatasetToolRepo
44-
}
45-
return "github.com/otto8-ai/datasets"
46-
}
47-
4849
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
4950
logger := gcontext.GetLogger(r.Context())
5051

@@ -65,7 +66,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
6566
return
6667
}
6768

68-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
69+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Datasets", loader.Options{
6970
Cache: g.Cache,
7071
})
7172

@@ -126,7 +127,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
126127
return
127128
}
128129

129-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{
130+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{
130131
Cache: g.Cache,
131132
})
132133

@@ -195,7 +196,7 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
195196
return
196197
}
197198

198-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{
199+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{
199200
Cache: g.Cache,
200201
})
201202
if err != nil {
@@ -262,7 +263,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
262263
return
263264
}
264265

265-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{
266+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Elements", loader.Options{
266267
Cache: g.Cache,
267268
})
268269
if err != nil {
@@ -327,7 +328,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
327328
return
328329
}
329330

330-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Elements", loader.Options{
331+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Elements", loader.Options{
331332
Cache: g.Cache,
332333
})
333334
if err != nil {
@@ -390,7 +391,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
390391
return
391392
}
392393

393-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element SDK", loader.Options{
394+
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{
394395
Cache: g.Cache,
395396
})
396397
if err != nil {

pkg/sdkserver/routes.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ import (
2626
)
2727

2828
type server struct {
29-
gptscriptOpts gptscript.Options
30-
address, token string
31-
workspaceTool string
32-
client *gptscript.GPTScript
33-
events *broadcaster.Broadcaster[event]
29+
gptscriptOpts gptscript.Options
30+
address, token string
31+
datasetTool, workspaceTool string
32+
client *gptscript.GPTScript
33+
events *broadcaster.Broadcaster[event]
3434

3535
runtimeManager engine.RuntimeManager
3636

pkg/sdkserver/server.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ import (
2626
type Options struct {
2727
gptscript.Options
2828

29-
ListenAddress string
30-
WorkspaceTool string
31-
Debug bool
32-
DisableServerErrorLogging bool
29+
ListenAddress string
30+
DatasetTool, WorkspaceTool string
31+
Debug bool
32+
DisableServerErrorLogging bool
3333
}
3434

3535
// Run will start the server and block until the server is shut down.
@@ -108,6 +108,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error {
108108
gptscriptOpts: opts.Options,
109109
address: listener.Addr().String(),
110110
token: token,
111+
datasetTool: opts.DatasetTool,
111112
workspaceTool: opts.WorkspaceTool,
112113
client: g,
113114
events: events,
@@ -159,6 +160,7 @@ func complete(opts ...Options) Options {
159160
for _, opt := range opts {
160161
result.Options = gptscript.Complete(result.Options, opt.Options)
161162
result.ListenAddress = types.FirstSet(opt.ListenAddress, result.ListenAddress)
163+
result.DatasetTool = types.FirstSet(opt.DatasetTool, result.DatasetTool)
162164
result.WorkspaceTool = types.FirstSet(opt.WorkspaceTool, result.WorkspaceTool)
163165
result.Debug = types.FirstSet(opt.Debug, result.Debug)
164166
result.DisableServerErrorLogging = types.FirstSet(opt.DisableServerErrorLogging, result.DisableServerErrorLogging)
@@ -171,6 +173,9 @@ func complete(opts ...Options) Options {
171173
if result.WorkspaceTool == "" {
172174
result.WorkspaceTool = "github.com/gptscript-ai/workspace-provider"
173175
}
176+
if result.DatasetTool == "" {
177+
result.DatasetTool = "github.com/otto8-ai/datasets"
178+
}
174179

175180
return result
176181
}

0 commit comments

Comments
 (0)