@@ -1,7 +1,12 @@
 package openai
 
 import (
+	"encoding/json"
+
 	openai "github.com/gptscript-ai/chat-completion-client"
+	"github.com/gptscript-ai/gptscript/pkg/types"
+	"github.com/pkoukk/tiktoken-go"
+	tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
 )
 
 const DefaultMaxTokens = 128_000
@@ -12,22 +17,26 @@ func decreaseTenPercent(maxTokens int) int {
 }
 
 func getBudget(maxTokens int) int {
-	if maxTokens == 0 {
+	if maxTokens <= 0 {
 		return DefaultMaxTokens
 	}
 	return maxTokens
 }
 
-func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
+func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage, err error) {
 	var (
 		lastSystem   int
 		withinBudget int
-		budget       = getBudget(maxTokens)
+		budget       = getBudget(maxTokens) - toolTokenCount
 	)
 
 	for i, msg := range msgs {
 		if msg.Role == openai.ChatMessageRoleSystem {
-			budget -= countMessage(msg)
+			count, err := countMessage(msg)
+			if err != nil {
+				return nil, err
+			}
+			budget -= count
 			lastSystem = i
 			result = append(result, msg)
 		} else {
@@ -37,7 +46,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 
 	for i := len(msgs) - 1; i > lastSystem; i-- {
 		withinBudget = i
-		budget -= countMessage(msgs[i])
+		count, err := countMessage(msgs[i])
+		if err != nil {
+			return nil, err
+		}
+		budget -= count
 		if budget <= 0 {
 			break
 		}
@@ -54,22 +67,44 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 	if withinBudget == len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
-		return msgs
+		return msgs, nil
 	}
 
-	return append(result, msgs[withinBudget:]...)
+	return append(result, msgs[withinBudget:]...), nil
 }
 
-func countMessage(msg openai.ChatCompletionMessage) (count int) {
-	count += len(msg.Role)
-	count += len(msg.Content)
+func countMessage(msg openai.ChatCompletionMessage) (int, error) {
+	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
+	encoding, err := tiktoken.GetEncoding("o200k_base")
+	if err != nil {
+		return 0, err
+	}
+
+	count := len(encoding.Encode(msg.Role, nil, nil))
+	count += len(encoding.Encode(msg.Content, nil, nil))
 	for _, content := range msg.MultiContent {
-		count += len(content.Text)
+		count += len(encoding.Encode(content.Text, nil, nil))
 	}
 	for _, tool := range msg.ToolCalls {
-		count += len(tool.Function.Name)
-		count += len(tool.Function.Arguments)
+		count += len(encoding.Encode(tool.Function.Name, nil, nil))
+		count += len(encoding.Encode(tool.Function.Arguments, nil, nil))
 	}
-	count += len(msg.ToolCallID)
-	return count / 3
+	count += len(encoding.Encode(msg.ToolCallID, nil, nil))
+
+	return count, nil
+}
+
+func countTools(tools []types.ChatCompletionTool) (int, error) {
+	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
+	encoding, err := tiktoken.GetEncoding("o200k_base")
+	if err != nil {
+		return 0, err
+	}
+
+	toolJSON, err := json.Marshal(tools)
+	if err != nil {
+		return 0, err
+	}
+
+	return len(encoding.Encode(string(toolJSON), nil, nil)), nil
 }
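
For reference, a minimal sketch of how a caller in this package might wire the two new pieces together: count the tool schema once with countTools, then pass that count to dropMessagesOverCount so the tool definitions are charged against the same context budget as the messages. The trimRequest helper below is hypothetical and not part of this diff; it only uses the functions and types shown above and assumes the caller lives in the same package, since both functions are unexported.

// trimRequest is a hypothetical helper (not part of this diff) showing how the
// countTools result feeds dropMessagesOverCount's new toolTokenCount parameter.
func trimRequest(maxTokens int, tools []types.ChatCompletionTool, msgs []openai.ChatCompletionMessage) ([]openai.ChatCompletionMessage, error) {
	// Tokens consumed by the serialized tool definitions reduce the budget
	// left for messages.
	toolTokenCount, err := countTools(tools)
	if err != nil {
		return nil, err
	}
	// Drop the oldest non-system messages until what remains fits the budget.
	return dropMessagesOverCount(maxTokens, toolTokenCount, msgs)
}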