@@ -132,7 +132,7 @@ func loadLocal(base *source, name string) (*source, bool, error) {
132
132
}, true , nil
133
133
}
134
134
135
- func loadProgram (data []byte , into * types.Program , targetToolName string ) (types.Tool , error ) {
135
+ func loadProgram (data []byte , into * types.Program , targetToolName , defaultModel string ) (types.Tool , error ) {
136
136
var ext types.Program
137
137
138
138
if err := json .Unmarshal (data [len (assemble .Header ):], & ext ); err != nil {
@@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
141
141
142
142
into .ToolSet = make (map [string ]types.Tool , len (ext .ToolSet ))
143
143
for k , v := range ext .ToolSet {
144
- if builtinTool , ok := builtin .Builtin ( k ); ok {
144
+ if builtinTool , ok := builtin .BuiltinWithDefaultModel ( k , defaultModel ); ok {
145
145
v = builtinTool
146
146
}
147
147
into .ToolSet [k ] = v
@@ -186,11 +186,11 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
186
186
return openAPIDocument
187
187
}
188
188
189
- func readTool (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , targetToolName string ) ([]types.Tool , error ) {
189
+ func readTool (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , targetToolName , defaultModel string ) ([]types.Tool , error ) {
190
190
data := base .Content
191
191
192
192
if bytes .HasPrefix (data , assemble .Header ) {
193
- tool , err := loadProgram (data , prg , targetToolName )
193
+ tool , err := loadProgram (data , prg , targetToolName , defaultModel )
194
194
if err != nil {
195
195
return nil , err
196
196
}
@@ -310,17 +310,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
310
310
localTools [strings .ToLower (tool .Parameters .Name )] = tool
311
311
}
312
312
313
- return linkAll (ctx , cache , prg , base , targetTools , localTools )
313
+ return linkAll (ctx , cache , prg , base , targetTools , localTools , defaultModel )
314
314
}
315
315
316
- func linkAll (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tools []types.Tool , localTools types.ToolSet ) (result []types.Tool , _ error ) {
316
+ func linkAll (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tools []types.Tool , localTools types.ToolSet , defaultModel string ) (result []types.Tool , _ error ) {
317
317
localToolsMapping := make (map [string ]string , len (tools ))
318
318
for _ , localTool := range localTools {
319
319
localToolsMapping [strings .ToLower (localTool .Parameters .Name )] = localTool .ID
320
320
}
321
321
322
322
for _ , tool := range tools {
323
- tool , err := link (ctx , cache , prg , base , tool , localTools , localToolsMapping )
323
+ tool , err := link (ctx , cache , prg , base , tool , localTools , localToolsMapping , defaultModel )
324
324
if err != nil {
325
325
return nil , err
326
326
}
@@ -329,7 +329,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base
329
329
return
330
330
}
331
331
332
- func link (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tool types.Tool , localTools types.ToolSet , localToolsMapping map [string ]string ) (types.Tool , error ) {
332
+ func link (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tool types.Tool , localTools types.ToolSet , localToolsMapping map [string ]string , defaultModel string ) (types.Tool , error ) {
333
333
if existing , ok := prg .ToolSet [tool .ID ]; ok {
334
334
return existing , nil
335
335
}
@@ -354,7 +354,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
354
354
linkedTool = existing
355
355
} else {
356
356
var err error
357
- linkedTool , err = link (ctx , cache , prg , base , localTool , localTools , localToolsMapping )
357
+ linkedTool , err = link (ctx , cache , prg , base , localTool , localTools , localToolsMapping , defaultModel )
358
358
if err != nil {
359
359
return types.Tool {}, fmt .Errorf ("failed linking %s at %s: %w" , targetToolName , base , err )
360
360
}
@@ -364,7 +364,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
364
364
toolNames [targetToolName ] = struct {}{}
365
365
} else {
366
366
toolName , subTool := types .SplitToolRef (targetToolName )
367
- resolvedTools , err := resolve (ctx , cache , prg , base , toolName , subTool )
367
+ resolvedTools , err := resolve (ctx , cache , prg , base , toolName , subTool , defaultModel )
368
368
if err != nil {
369
369
return types.Tool {}, fmt .Errorf ("failed resolving %s from %s: %w" , targetToolName , base , err )
370
370
}
@@ -376,6 +376,10 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
376
376
377
377
tool .LocalTools = localToolsMapping
378
378
379
+ if defaultModel != "" && tool .ModelName == "" {
380
+ tool .ModelName = defaultModel
381
+ }
382
+
379
383
tool = builtin .SetDefaults (tool )
380
384
prg .ToolSet [tool .ID ] = tool
381
385
@@ -405,7 +409,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
405
409
Path : locationPath ,
406
410
Name : locationName ,
407
411
Location : opt .Location ,
408
- }, subToolName )
412
+ }, subToolName , opt . DefaultModel )
409
413
if err != nil {
410
414
return types.Program {}, err
411
415
}
@@ -414,20 +418,26 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
414
418
}
415
419
416
420
type Options struct {
417
- Cache * cache.Client
418
- Location string
421
+ Cache * cache.Client
422
+ Location string
423
+ DefaultModel string
419
424
}
420
425
421
426
func complete (opts ... Options ) (result Options ) {
422
427
for _ , opt := range opts {
423
428
result .Cache = types .FirstSet (opt .Cache , result .Cache )
424
429
result .Location = types .FirstSet (opt .Location , result .Location )
430
+ result .DefaultModel = types .FirstSet (opt .DefaultModel , result .DefaultModel )
425
431
}
426
432
427
433
if result .Location == "" {
428
434
result .Location = "inline"
429
435
}
430
436
437
+ if result .DefaultModel == "" {
438
+ result .DefaultModel = builtin .GetDefaultModel ()
439
+ }
440
+
431
441
return
432
442
}
433
443
@@ -451,17 +461,17 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
451
461
Name : name ,
452
462
ToolSet : types.ToolSet {},
453
463
}
454
- tools , err := resolve (ctx , opt .Cache , & prg , & source {}, name , subToolName )
464
+ tools , err := resolve (ctx , opt .Cache , & prg , & source {}, name , subToolName , opt . DefaultModel )
455
465
if err != nil {
456
466
return types.Program {}, err
457
467
}
458
468
prg .EntryToolID = tools [0 ].ID
459
469
return prg , nil
460
470
}
461
471
462
- func resolve (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , name , subTool string ) ([]types.Tool , error ) {
472
+ func resolve (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , name , subTool , defaultModel string ) ([]types.Tool , error ) {
463
473
if subTool == "" {
464
- t , ok := builtin .Builtin (name )
474
+ t , ok := builtin .BuiltinWithDefaultModel (name , defaultModel )
465
475
if ok {
466
476
prg .ToolSet [t .ID ] = t
467
477
return []types.Tool {t }, nil
@@ -473,7 +483,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base
473
483
return nil , err
474
484
}
475
485
476
- result , err := readTool (ctx , cache , prg , s , subTool )
486
+ result , err := readTool (ctx , cache , prg , s , subTool , defaultModel )
477
487
if err != nil {
478
488
return nil , err
479
489
}
0 commit comments