11package openai
22
import (
	"encoding/json"
	"sync"

	openai "github.com/gptscript-ai/chat-completion-client"
	"github.com/gptscript-ai/gptscript/pkg/types"
	"github.com/pkoukk/tiktoken-go"
	tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
)
611
712const DefaultMaxTokens = 128_000
@@ -12,22 +17,26 @@ func decreaseTenPercent(maxTokens int) int {
1217}
1318
1419func getBudget (maxTokens int ) int {
15- if maxTokens = = 0 {
20+ if maxTokens < = 0 {
1621 return DefaultMaxTokens
1722 }
1823 return maxTokens
1924}
2025
21- func dropMessagesOverCount (maxTokens int , msgs []openai.ChatCompletionMessage ) (result []openai.ChatCompletionMessage ) {
26+ func dropMessagesOverCount (maxTokens , toolTokenCount int , msgs []openai.ChatCompletionMessage ) (result []openai.ChatCompletionMessage , err error ) {
2227 var (
2328 lastSystem int
2429 withinBudget int
25- budget = getBudget (maxTokens )
30+ budget = getBudget (maxTokens ) - toolTokenCount
2631 )
2732
2833 for i , msg := range msgs {
2934 if msg .Role == openai .ChatMessageRoleSystem {
30- budget -= countMessage (msg )
35+ count , err := countMessage (msg )
36+ if err != nil {
37+ return nil , err
38+ }
39+ budget -= count
3140 lastSystem = i
3241 result = append (result , msg )
3342 } else {
@@ -37,7 +46,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
3746
3847 for i := len (msgs ) - 1 ; i > lastSystem ; i -- {
3948 withinBudget = i
40- budget -= countMessage (msgs [i ])
49+ count , err := countMessage (msgs [i ])
50+ if err != nil {
51+ return nil , err
52+ }
53+ budget -= count
4154 if budget <= 0 {
4255 break
4356 }
@@ -54,22 +67,44 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
5467 if withinBudget == len (msgs )- 1 {
5568 // We are going to drop all non system messages, which seems useless, so just return them
5669 // all and let it fail
57- return msgs
70+ return msgs , nil
5871 }
5972
60- return append (result , msgs [withinBudget :]... )
73+ return append (result , msgs [withinBudget :]... ), nil
6174}
6275
63- func countMessage (msg openai.ChatCompletionMessage ) (count int ) {
64- count += len (msg .Role )
65- count += len (msg .Content )
76+ func countMessage (msg openai.ChatCompletionMessage ) (int , error ) {
77+ tiktoken .SetBpeLoader (tiktoken_loader .NewOfflineLoader ())
78+ encoding , err := tiktoken .GetEncoding ("o200k_base" )
79+ if err != nil {
80+ return 0 , err
81+ }
82+
83+ count := len (encoding .Encode (msg .Role , nil , nil ))
84+ count += len (encoding .Encode (msg .Content , nil , nil ))
6685 for _ , content := range msg .MultiContent {
67- count += len (content .Text )
86+ count += len (encoding . Encode ( content .Text , nil , nil ) )
6887 }
6988 for _ , tool := range msg .ToolCalls {
70- count += len (tool .Function .Name )
71- count += len (tool .Function .Arguments )
89+ count += len (encoding . Encode ( tool .Function .Name , nil , nil ) )
90+ count += len (encoding . Encode ( tool .Function .Arguments , nil , nil ) )
7291 }
73- count += len (msg .ToolCallID )
74- return count / 3
92+ count += len (encoding .Encode (msg .ToolCallID , nil , nil ))
93+
94+ return count , nil
95+ }
96+
97+ func countTools (tools []types.ChatCompletionTool ) (int , error ) {
98+ tiktoken .SetBpeLoader (tiktoken_loader .NewOfflineLoader ())
99+ encoding , err := tiktoken .GetEncoding ("o200k_base" )
100+ if err != nil {
101+ return 0 , err
102+ }
103+
104+ toolJSON , err := json .Marshal (tools )
105+ if err != nil {
106+ return 0 , err
107+ }
108+
109+ return len (encoding .Encode (string (toolJSON ), nil , nil )), nil
75110}
0 commit comments