@@ -240,7 +240,7 @@ func toToolCall(call types.CompletionToolCall) openai.ToolCall {
 	}
 }
 
-func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
+func toMessages(request types.CompletionRequest, compat, useO1Model bool) (result []openai.ChatCompletionMessage, err error) {
 	var (
 		systemPrompts []string
 		msgs          []types.CompletionMessage
@@ -259,8 +259,12 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
 	}
 
 	if len(systemPrompts) > 0 {
+		role := types.CompletionMessageRoleTypeSystem
+		if useO1Model {
+			role = types.CompletionMessageRoleTypeDeveloper
+		}
 		msgs = slices.Insert(msgs, 0, types.CompletionMessage{
-			Role:    types.CompletionMessageRoleTypeSystem,
+			Role:    role,
 			Content: types.Text(strings.Join(systemPrompts, "\n")),
 		})
 	}
@@ -306,9 +310,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
 	return
 }
 
-func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
+func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, envs []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if err := c.ValidAuth(); err != nil {
-		if err := c.RetrieveAPIKey(ctx, env); err != nil {
+		if err := c.RetrieveAPIKey(ctx, envs); err != nil {
 			return nil, err
 		}
 	}
@@ -317,7 +321,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 		messageRequest.Model = c.defaultModel
 	}
 
-	msgs, err := toMessages(messageRequest, !c.setSeed)
+	useO1Model := isO1Model(messageRequest.Model, envs)
+
+	msgs, err := toMessages(messageRequest, !c.setSeed, useO1Model)
 	if err != nil {
 		return nil, err
 	}
@@ -348,10 +354,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 		MaxTokens: messageRequest.MaxTokens,
 	}
 
-	if messageRequest.Temperature == nil {
-		request.Temperature = new(float32)
-	} else {
-		request.Temperature = messageRequest.Temperature
+	// OpenAI o1 doesn't support setting temperature
+	if !useO1Model {
+		if messageRequest.Temperature == nil {
+			request.Temperature = new(float32)
+		} else {
+			request.Temperature = messageRequest.Temperature
+		}
 	}
 
 	if messageRequest.JSONResponse {
@@ -404,15 +413,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if err != nil {
 		return nil, err
 	} else if !ok {
-		result, err = c.call(ctx, request, id, env, status)
+		result, err = c.call(ctx, request, id, envs, status)
 
 		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
 		var apiError *openai.APIError
 		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
 			// Decrease maxTokens by 10% to make garbage collection more aggressive.
 			// The retry loop will further decrease maxTokens if needed.
 			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-			result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
+			result, err = c.contextLimitRetryLoop(ctx, request, id, envs, maxTokens, status)
 		}
 		if err != nil {
 			return nil, err
@@ -446,6 +455,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	return &result, nil
 }
 
+func isO1Model(model string, envs []string) bool {
+	if model == "o1" {
+		return true
+	}
+
+	o1Model := false
+	for _, env := range envs {
+		k, v, _ := strings.Cut(env, "=")
+		if k == "OPENAI_MODEL_NAME" && v == "o1" {
+			o1Model = true
+		}
+	}
+
+	return o1Model
+}
+
 func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	var (
 		response types.CompletionMessage
@@ -545,9 +570,14 @@ func override(left, right string) string {
 	return left
 }
 
-func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, envs []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
 
+	useO1Model := isO1Model(request.Model, envs)
+	if useO1Model {
+		streamResponse = false
+	}
+
 	partial <- types.CompletionStatus{
 		CompletionID:    transactionID,
 		PartialResponse: &types.CompletionMessage{
@@ -567,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 			},
 		}
 	)
-	for _, e := range env {
+	for _, e := range envs {
 		if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
 			modelProviderEnv = append(modelProviderEnv, e)
 		} else if strings.HasPrefix(e, "GPTSCRIPT_DISABLE_RETRIES") {
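A quick way to sanity-check the model-detection fallback introduced above: a minimal table-driven test sketch. This is hypothetical, not part of the diff; it assumes the test sits in the same package as the unexported isO1Model (e.g. a client_test.go alongside the client).

package openai

import "testing"

// Sketch only: exercises the two detection paths in isO1Model, the
// literal model name and the OPENAI_MODEL_NAME env-var override.
func TestIsO1Model(t *testing.T) {
	cases := []struct {
		name  string
		model string
		envs  []string
		want  bool
	}{
		{name: "model is o1", model: "o1", want: true},
		{name: "env override", model: "gpt-4o", envs: []string{"OPENAI_MODEL_NAME=o1"}, want: true},
		{name: "non-o1 model", model: "gpt-4o", envs: []string{"OPENAI_MODEL_NAME=gpt-4o"}, want: false},
	}
	for _, tc := range cases {
		if got := isO1Model(tc.model, tc.envs); got != tc.want {
			t.Errorf("%s: got %v, want %v", tc.name, got, tc.want)
		}
	}
}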