@@ -132,7 +132,7 @@ func loadLocal(base *source, name string) (*source, bool, error) {
132132 }, true , nil
133133}
134134
135- func loadProgram (data []byte , into * types.Program , targetToolName string ) (types.Tool , error ) {
135+ func loadProgram (data []byte , into * types.Program , targetToolName , defaultModel string ) (types.Tool , error ) {
136136 var ext types.Program
137137
138138 if err := json .Unmarshal (data [len (assemble .Header ):], & ext ); err != nil {
@@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
141141
142142 into .ToolSet = make (map [string ]types.Tool , len (ext .ToolSet ))
143143 for k , v := range ext .ToolSet {
144- if builtinTool , ok := builtin .Builtin ( k ); ok {
144+ if builtinTool , ok := builtin .BuiltinWithDefaultModel ( k , defaultModel ); ok {
145145 v = builtinTool
146146 }
147147 into .ToolSet [k ] = v
@@ -186,11 +186,11 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
186186 return openAPIDocument
187187}
188188
189- func readTool (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , targetToolName string ) ([]types.Tool , error ) {
189+ func readTool (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , targetToolName , defaultModel string ) ([]types.Tool , error ) {
190190 data := base .Content
191191
192192 if bytes .HasPrefix (data , assemble .Header ) {
193- tool , err := loadProgram (data , prg , targetToolName )
193+ tool , err := loadProgram (data , prg , targetToolName , defaultModel )
194194 if err != nil {
195195 return nil , err
196196 }
@@ -310,17 +310,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
310310 localTools [strings .ToLower (tool .Parameters .Name )] = tool
311311 }
312312
313- return linkAll (ctx , cache , prg , base , targetTools , localTools )
313+ return linkAll (ctx , cache , prg , base , targetTools , localTools , defaultModel )
314314}
315315
316- func linkAll (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tools []types.Tool , localTools types.ToolSet ) (result []types.Tool , _ error ) {
316+ func linkAll (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tools []types.Tool , localTools types.ToolSet , defaultModel string ) (result []types.Tool , _ error ) {
317317 localToolsMapping := make (map [string ]string , len (tools ))
318318 for _ , localTool := range localTools {
319319 localToolsMapping [strings .ToLower (localTool .Parameters .Name )] = localTool .ID
320320 }
321321
322322 for _ , tool := range tools {
323- tool , err := link (ctx , cache , prg , base , tool , localTools , localToolsMapping )
323+ tool , err := link (ctx , cache , prg , base , tool , localTools , localToolsMapping , defaultModel )
324324 if err != nil {
325325 return nil , err
326326 }
@@ -329,7 +329,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base
329329 return
330330}
331331
332- func link (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tool types.Tool , localTools types.ToolSet , localToolsMapping map [string ]string ) (types.Tool , error ) {
332+ func link (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , tool types.Tool , localTools types.ToolSet , localToolsMapping map [string ]string , defaultModel string ) (types.Tool , error ) {
333333 if existing , ok := prg .ToolSet [tool .ID ]; ok {
334334 return existing , nil
335335 }
@@ -354,7 +354,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
354354 linkedTool = existing
355355 } else {
356356 var err error
357- linkedTool , err = link (ctx , cache , prg , base , localTool , localTools , localToolsMapping )
357+ linkedTool , err = link (ctx , cache , prg , base , localTool , localTools , localToolsMapping , defaultModel )
358358 if err != nil {
359359 return types.Tool {}, fmt .Errorf ("failed linking %s at %s: %w" , targetToolName , base , err )
360360 }
@@ -364,7 +364,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
364364 toolNames [targetToolName ] = struct {}{}
365365 } else {
366366 toolName , subTool := types .SplitToolRef (targetToolName )
367- resolvedTools , err := resolve (ctx , cache , prg , base , toolName , subTool )
367+ resolvedTools , err := resolve (ctx , cache , prg , base , toolName , subTool , defaultModel )
368368 if err != nil {
369369 return types.Tool {}, fmt .Errorf ("failed resolving %s from %s: %w" , targetToolName , base , err )
370370 }
@@ -376,6 +376,10 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
376376
377377 tool .LocalTools = localToolsMapping
378378
379+ if tool .ModelName == "" {
380+ tool .ModelName = defaultModel
381+ }
382+
379383 tool = builtin .SetDefaults (tool )
380384 prg .ToolSet [tool .ID ] = tool
381385
@@ -405,7 +409,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
405409 Path : locationPath ,
406410 Name : locationName ,
407411 Location : opt .Location ,
408- }, subToolName )
412+ }, subToolName , opt . DefaultModel )
409413 if err != nil {
410414 return types.Program {}, err
411415 }
@@ -414,20 +418,26 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
414418}
415419
416420type Options struct {
417- Cache * cache.Client
418- Location string
421+ Cache * cache.Client
422+ Location string
423+ DefaultModel string
419424}
420425
421426func complete (opts ... Options ) (result Options ) {
422427 for _ , opt := range opts {
423428 result .Cache = types .FirstSet (opt .Cache , result .Cache )
424429 result .Location = types .FirstSet (opt .Location , result .Location )
430+ result .DefaultModel = types .FirstSet (opt .DefaultModel , result .DefaultModel )
425431 }
426432
427433 if result .Location == "" {
428434 result .Location = "inline"
429435 }
430436
437+ if result .DefaultModel == "" {
438+ result .DefaultModel = builtin .GetDefaultModel ()
439+ }
440+
431441 return
432442}
433443
@@ -451,17 +461,17 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
451461 Name : name ,
452462 ToolSet : types.ToolSet {},
453463 }
454- tools , err := resolve (ctx , opt .Cache , & prg , & source {}, name , subToolName )
464+ tools , err := resolve (ctx , opt .Cache , & prg , & source {}, name , subToolName , opt . DefaultModel )
455465 if err != nil {
456466 return types.Program {}, err
457467 }
458468 prg .EntryToolID = tools [0 ].ID
459469 return prg , nil
460470}
461471
462- func resolve (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , name , subTool string ) ([]types.Tool , error ) {
472+ func resolve (ctx context.Context , cache * cache.Client , prg * types.Program , base * source , name , subTool , defaultModel string ) ([]types.Tool , error ) {
463473 if subTool == "" {
464- t , ok := builtin .Builtin (name )
474+ t , ok := builtin .BuiltinWithDefaultModel (name , defaultModel )
465475 if ok {
466476 prg .ToolSet [t .ID ] = t
467477 return []types.Tool {t }, nil
@@ -473,7 +483,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base
473483 return nil , err
474484 }
475485
476- result , err := readTool (ctx , cache , prg , s , subTool )
486+ result , err := readTool (ctx , cache , prg , s , subTool , defaultModel )
477487 if err != nil {
478488 return nil , err
479489 }
0 commit comments