@@ -13,7 +13,6 @@ import (
1313
1414 openai "github.com/gptscript-ai/chat-completion-client"
1515 "github.com/gptscript-ai/gptscript/pkg/cache"
16- gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1716 "github.com/gptscript-ai/gptscript/pkg/counter"
1817 "github.com/gptscript-ai/gptscript/pkg/credentials"
1918 "github.com/gptscript-ai/gptscript/pkg/hash"
@@ -305,7 +304,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
305304
306305func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
307306 if err := c .ValidAuth (); err != nil {
308- if err := c .RetrieveAPIKey (ctx ); err != nil {
307+ if err := c .RetrieveAPIKey (ctx , messageRequest . Env ); err != nil {
309308 return nil , err
310309 }
311310 }
@@ -401,15 +400,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
401400 if err != nil {
402401 return nil , err
403402 } else if ! ok {
404- result , err = c .call (ctx , request , id , status )
403+ result , err = c .call (ctx , request , id , messageRequest . Env , status )
405404
406405 // If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
407406 var apiError * openai.APIError
408407 if errors .As (err , & apiError ) && apiError .Code == "context_length_exceeded" && messageRequest .Chat {
409408 // Decrease maxTokens by 10% to make garbage collection more aggressive.
410409 // The retry loop will further decrease maxTokens if needed.
411410 maxTokens := decreaseTenPercent (messageRequest .MaxTokens )
412- result , err = c .contextLimitRetryLoop (ctx , request , id , maxTokens , status )
411+ result , err = c .contextLimitRetryLoop (ctx , request , id , messageRequest . Env , maxTokens , status )
413412 }
414413 if err != nil {
415414 return nil , err
@@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
443442 return & result , nil
444443}
445444
446- func (c * Client ) contextLimitRetryLoop (ctx context.Context , request openai.ChatCompletionRequest , id string , maxTokens int , status chan <- types.CompletionStatus ) (types.CompletionMessage , error ) {
445+ func (c * Client ) contextLimitRetryLoop (ctx context.Context , request openai.ChatCompletionRequest , id string , env [] string , maxTokens int , status chan <- types.CompletionStatus ) (types.CompletionMessage , error ) {
447446 var (
448447 response types.CompletionMessage
449448 err error
@@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
452451 for range 10 { // maximum 10 tries
453452 // Try to drop older messages again, with a decreased max tokens.
454453 request .Messages = dropMessagesOverCount (maxTokens , request .Messages )
455- response , err = c .call (ctx , request , id , status )
454+ response , err = c .call (ctx , request , id , env , status )
456455 if err == nil {
457456 return response , nil
458457 }
@@ -542,7 +541,7 @@ func override(left, right string) string {
542541 return left
543542}
544543
545- func (c * Client ) call (ctx context.Context , request openai.ChatCompletionRequest , transactionID string , partial chan <- types.CompletionStatus ) (types.CompletionMessage , error ) {
544+ func (c * Client ) call (ctx context.Context , request openai.ChatCompletionRequest , transactionID string , env [] string , partial chan <- types.CompletionStatus ) (types.CompletionMessage , error ) {
546545 streamResponse := os .Getenv ("GPTSCRIPT_INTERNAL_OPENAI_STREAMING" ) != "false"
547546
548547 partial <- types.CompletionStatus {
@@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
553552 },
554553 }
555554
555+ var (
556+ headers map [string ]string
557+ modelProviderEnv []string
558+ )
559+ for _ , e := range env {
560+ if strings .HasPrefix (e , "GPTSCRIPT_MODEL_PROVIDER_" ) {
561+ modelProviderEnv = append (modelProviderEnv , e )
562+ }
563+ }
564+
565+ if len (modelProviderEnv ) > 0 {
566+ headers = map [string ]string {
567+ "X-GPTScript-Env" : strings .Join (modelProviderEnv , "," ),
568+ }
569+ }
570+
556571 slog .Debug ("calling openai" , "message" , request .Messages )
557572
558573 if ! streamResponse {
559574 request .StreamOptions = nil
560- resp , err := c .c .CreateChatCompletion (ctx , request )
575+ resp , err := c .c .CreateChatCompletion (ctx , request , headers )
561576 if err != nil {
562577 return types.CompletionMessage {}, err
563578 }
@@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
582597 }), nil
583598 }
584599
585- stream , err := c .c .CreateChatCompletionStream (ctx , request )
600+ stream , err := c .c .CreateChatCompletionStream (ctx , request , headers )
586601 if err != nil {
587602 return types.CompletionMessage {}, err
588603 }
@@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
614629 }
615630}
616631
617- func (c * Client ) RetrieveAPIKey (ctx context.Context ) error {
618- k , err := prompt .GetModelProviderCredential (ctx , c .credStore , BuiltinCredName , "OPENAI_API_KEY" , "Please provide your OpenAI API key:" , gcontext . GetEnv ( ctx ) )
632+ func (c * Client ) RetrieveAPIKey (ctx context.Context , env [] string ) error {
633+ k , err := prompt .GetModelProviderCredential (ctx , c .credStore , BuiltinCredName , "OPENAI_API_KEY" , "Please provide your OpenAI API key:" , env )
619634 if err != nil {
620635 return err
621636 }
0 commit comments