@@ -2,6 +2,7 @@ package openai
22
33import (
44 "context"
5+ "errors"
56 "io"
67 "log/slog"
78 "os"
@@ -24,6 +25,7 @@ import (
2425const (
2526 DefaultModel = openai .GPT4o
2627 BuiltinCredName = "sys.openai"
28+ TooLongMessage = "Error: tool call output is too long"
2729)
2830
2931var (
@@ -317,6 +319,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
317319 }
318320
319321 if messageRequest .Chat {
322+ // Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
323+ lastMessage := msgs [len (msgs )- 1 ]
324+ if lastMessage .Role == string (types .CompletionMessageRoleTypeTool ) && countMessage (lastMessage ) > int (float64 (getBudget (messageRequest .MaxTokens ))* 0.8 ) {
325+ // We need to update it in the msgs slice for right now and in the messageRequest for future calls.
326+ msgs [len (msgs )- 1 ].Content = TooLongMessage
327+ messageRequest .Messages [len (messageRequest .Messages )- 1 ].Content = types .Text (TooLongMessage )
328+ }
329+
320330 msgs = dropMessagesOverCount (messageRequest .MaxTokens , msgs )
321331 }
322332
@@ -383,6 +393,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
383393 return nil , err
384394 } else if ! ok {
385395 response , err = c .call (ctx , request , id , status )
396+
397+ // If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
398+ var apiError * openai.APIError
399+ if errors .As (err , & apiError ) && apiError .Code == "context_length_exceeded" && messageRequest .Chat {
400+ // Decrease maxTokens by 10% to make garbage collection more aggressive.
401+ // The retry loop will further decrease maxTokens if needed.
402+ maxTokens := decreaseTenPercent (messageRequest .MaxTokens )
403+ response , err = c .contextLimitRetryLoop (ctx , request , id , maxTokens , status )
404+ }
405+
386406 if err != nil {
387407 return nil , err
388408 }
@@ -421,6 +441,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
421441 return & result , nil
422442}
423443
444+ func (c * Client ) contextLimitRetryLoop (ctx context.Context , request openai.ChatCompletionRequest , id string , maxTokens int , status chan <- types.CompletionStatus ) ([]openai.ChatCompletionStreamResponse , error ) {
445+ var (
446+ response []openai.ChatCompletionStreamResponse
447+ err error
448+ )
449+
450+ for range 10 { // maximum 10 tries
451+ // Try to drop older messages again, with a decreased max tokens.
452+ request .Messages = dropMessagesOverCount (maxTokens , request .Messages )
453+ response , err = c .call (ctx , request , id , status )
454+ if err == nil {
455+ return response , nil
456+ }
457+
458+ var apiError * openai.APIError
459+ if errors .As (err , & apiError ) && apiError .Code == "context_length_exceeded" {
460+ // Decrease maxTokens and try again
461+ maxTokens = decreaseTenPercent (maxTokens )
462+ continue
463+ }
464+ return nil , err
465+ }
466+
467+ return nil , err
468+ }
469+
424470func appendMessage (msg types.CompletionMessage , response openai.ChatCompletionStreamResponse ) types.CompletionMessage {
425471 msg .Usage .CompletionTokens = types .FirstSet (msg .Usage .CompletionTokens , response .Usage .CompletionTokens )
426472 msg .Usage .PromptTokens = types .FirstSet (msg .Usage .PromptTokens , response .Usage .PromptTokens )
0 commit comments