Skip to content

Commit 378eb46

Browse files
committed
fix: count tokens from tool definitions when adjusting for context window
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent 7ee5c80 commit 378eb46

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

pkg/openai/client.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,12 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
331331
messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
332332
}
333333

334-
msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
334+
toolsCount, err := countChatCompletionTools(messageRequest.Tools)
335+
if err != nil {
336+
return nil, err
337+
}
338+
339+
msgs = dropMessagesOverCount(messageRequest.MaxTokens-toolsCount, msgs)
335340
}
336341

337342
if len(msgs) == 0 {
@@ -447,14 +452,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
447452
}
448453

449454
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
450-
var (
451-
response types.CompletionMessage
452-
err error
453-
)
455+
toolsCount, err := countOpenAITools(request.Tools)
456+
if err != nil {
457+
return types.CompletionMessage{}, err
458+
}
454459

460+
var response types.CompletionMessage
455461
for range 10 { // maximum 10 tries
456462
// Try to drop older messages again, with a decreased max tokens.
457-
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
463+
request.Messages = dropMessagesOverCount(maxTokens-toolsCount, request.Messages)
458464
response, err = c.call(ctx, request, id, env, status)
459465
if err == nil {
460466
return response, nil

pkg/openai/count.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package openai
22

33
import (
4+
"encoding/json"
5+
46
openai "github.com/gptscript-ai/chat-completion-client"
7+
"github.com/gptscript-ai/gptscript/pkg/types"
58
)
69

710
const DefaultMaxTokens = 128_000
@@ -73,3 +76,29 @@ func countMessage(msg openai.ChatCompletionMessage) (count int) {
7376
count += len(msg.ToolCallID)
7477
return count / 3
7578
}
79+
80+
func countChatCompletionTools(tools []types.ChatCompletionTool) (count int, err error) {
81+
for _, t := range tools {
82+
count += len(t.Function.Name)
83+
count += len(t.Function.Description)
84+
paramsJSON, err := json.Marshal(t.Function.Parameters)
85+
if err != nil {
86+
return 0, err
87+
}
88+
count += len(paramsJSON)
89+
}
90+
return count / 3, nil
91+
}
92+
93+
func countOpenAITools(tools []openai.Tool) (count int, err error) {
94+
for _, t := range tools {
95+
count += len(t.Function.Name)
96+
count += len(t.Function.Description)
97+
paramsJSON, err := json.Marshal(t.Function.Parameters)
98+
if err != nil {
99+
return 0, err
100+
}
101+
count += len(paramsJSON)
102+
}
103+
return count / 3, nil
104+
}

0 commit comments

Comments
 (0)