Skip to content

Commit db48ed6

Browse files
committed
fix: add default model to the loader
The tool loader will set the models on the tools if none is set. The way that that happens works for the CLI, but is not compatible with the SDK. This change makes the default model logic work with the SDK server. Signed-off-by: Donnie Adams <[email protected]>
1 parent 61e6ded commit db48ed6

File tree

3 files changed

+37
-18
lines changed

3 files changed

+37
-18
lines changed

pkg/builtin/builtin.go

+5
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,15 @@ func ListTools() (result []types.Tool) {
277277
}
278278

279279
func Builtin(name string) (types.Tool, bool) {
280+
return BuiltinWithDefaultModel(name, "")
281+
}
282+
283+
func BuiltinWithDefaultModel(name, defaultModel string) (types.Tool, bool) {
280284
// Legacy syntax not used anymore
281285
name = strings.TrimSuffix(name, "?")
282286
t, ok := tools[name]
283287
t.Parameters.Name = name
288+
t.Parameters.ModelName = defaultModel
284289
t.ID = name
285290
t.Instructions = "#!" + name
286291
return SetDefaults(t), ok

pkg/loader/loader.go

+27-17
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func loadLocal(base *source, name string) (*source, bool, error) {
132132
}, true, nil
133133
}
134134

135-
func loadProgram(data []byte, into *types.Program, targetToolName string) (types.Tool, error) {
135+
func loadProgram(data []byte, into *types.Program, targetToolName, defaultModel string) (types.Tool, error) {
136136
var ext types.Program
137137

138138
if err := json.Unmarshal(data[len(assemble.Header):], &ext); err != nil {
@@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
141141

142142
into.ToolSet = make(map[string]types.Tool, len(ext.ToolSet))
143143
for k, v := range ext.ToolSet {
144-
if builtinTool, ok := builtin.Builtin(k); ok {
144+
if builtinTool, ok := builtin.BuiltinWithDefaultModel(k, defaultModel); ok {
145145
v = builtinTool
146146
}
147147
into.ToolSet[k] = v
@@ -186,11 +186,11 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
186186
return openAPIDocument
187187
}
188188

189-
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) {
189+
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) {
190190
data := base.Content
191191

192192
if bytes.HasPrefix(data, assemble.Header) {
193-
tool, err := loadProgram(data, prg, targetToolName)
193+
tool, err := loadProgram(data, prg, targetToolName, defaultModel)
194194
if err != nil {
195195
return nil, err
196196
}
@@ -310,17 +310,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
310310
localTools[strings.ToLower(tool.Parameters.Name)] = tool
311311
}
312312

313-
return linkAll(ctx, cache, prg, base, targetTools, localTools)
313+
return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel)
314314
}
315315

316-
func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) {
316+
func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) {
317317
localToolsMapping := make(map[string]string, len(tools))
318318
for _, localTool := range localTools {
319319
localToolsMapping[strings.ToLower(localTool.Parameters.Name)] = localTool.ID
320320
}
321321

322322
for _, tool := range tools {
323-
tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping)
323+
tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel)
324324
if err != nil {
325325
return nil, err
326326
}
@@ -329,7 +329,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base
329329
return
330330
}
331331

332-
func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string) (types.Tool, error) {
332+
func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) {
333333
if existing, ok := prg.ToolSet[tool.ID]; ok {
334334
return existing, nil
335335
}
@@ -354,7 +354,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
354354
linkedTool = existing
355355
} else {
356356
var err error
357-
linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping)
357+
linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel)
358358
if err != nil {
359359
return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err)
360360
}
@@ -364,7 +364,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
364364
toolNames[targetToolName] = struct{}{}
365365
} else {
366366
toolName, subTool := types.SplitToolRef(targetToolName)
367-
resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool)
367+
resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel)
368368
if err != nil {
369369
return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err)
370370
}
@@ -376,6 +376,10 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
376376

377377
tool.LocalTools = localToolsMapping
378378

379+
if defaultModel != "" && tool.ModelName == "" {
380+
tool.ModelName = defaultModel
381+
}
382+
379383
tool = builtin.SetDefaults(tool)
380384
prg.ToolSet[tool.ID] = tool
381385

@@ -405,7 +409,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
405409
Path: locationPath,
406410
Name: locationName,
407411
Location: opt.Location,
408-
}, subToolName)
412+
}, subToolName, opt.DefaultModel)
409413
if err != nil {
410414
return types.Program{}, err
411415
}
@@ -414,20 +418,26 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
414418
}
415419

416420
type Options struct {
417-
Cache *cache.Client
418-
Location string
421+
Cache *cache.Client
422+
Location string
423+
DefaultModel string
419424
}
420425

421426
func complete(opts ...Options) (result Options) {
422427
for _, opt := range opts {
423428
result.Cache = types.FirstSet(opt.Cache, result.Cache)
424429
result.Location = types.FirstSet(opt.Location, result.Location)
430+
result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel)
425431
}
426432

427433
if result.Location == "" {
428434
result.Location = "inline"
429435
}
430436

437+
if result.DefaultModel == "" {
438+
result.DefaultModel = builtin.GetDefaultModel()
439+
}
440+
431441
return
432442
}
433443

@@ -451,17 +461,17 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
451461
Name: name,
452462
ToolSet: types.ToolSet{},
453463
}
454-
tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName)
464+
tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel)
455465
if err != nil {
456466
return types.Program{}, err
457467
}
458468
prg.EntryToolID = tools[0].ID
459469
return prg, nil
460470
}
461471

462-
func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) {
472+
func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) {
463473
if subTool == "" {
464-
t, ok := builtin.Builtin(name)
474+
t, ok := builtin.BuiltinWithDefaultModel(name, defaultModel)
465475
if ok {
466476
prg.ToolSet[t.ID] = t
467477
return []types.Tool{t}, nil
@@ -473,7 +483,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base
473483
return nil, err
474484
}
475485

476-
result, err := readTool(ctx, cache, prg, s, subTool)
486+
result, err := readTool(ctx, cache, prg, s, subTool, defaultModel)
477487
if err != nil {
478488
return nil, err
479489
}

pkg/sdkserver/run.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo
3232
}
3333
defer g.Close(false)
3434

35-
prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache})
35+
defaultModel := opts.OpenAI.DefaultModel
36+
if defaultModel == "" {
37+
defaultModel = s.gptscriptOpts.OpenAI.DefaultModel
38+
}
39+
prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache, DefaultModel: defaultModel})
3640
if err != nil {
3741
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
3842
return

0 commit comments

Comments
 (0)