1
1
package openai
2
2
3
3
import (
4
+ "encoding/json"
5
+
4
6
openai "github.com/gptscript-ai/chat-completion-client"
7
+ "github.com/gptscript-ai/gptscript/pkg/types"
8
+ "github.com/pkoukk/tiktoken-go"
9
+ tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
5
10
)
6
11
12
+ func init () {
13
+ tiktoken .SetBpeLoader (tiktoken_loader .NewOfflineLoader ())
14
+ }
15
+
7
16
// DefaultMaxTokens is the token budget that getBudget falls back to when the
// caller does not supply a positive maxTokens value.
const DefaultMaxTokens = 128_000
8
17
9
18
func decreaseTenPercent (maxTokens int ) int {
@@ -12,22 +21,26 @@ func decreaseTenPercent(maxTokens int) int {
12
21
}
13
22
14
23
func getBudget (maxTokens int ) int {
15
- if maxTokens = = 0 {
24
+ if maxTokens < = 0 {
16
25
return DefaultMaxTokens
17
26
}
18
27
return maxTokens
19
28
}
20
29
21
- func dropMessagesOverCount (maxTokens int , msgs []openai.ChatCompletionMessage ) (result []openai.ChatCompletionMessage ) {
30
+ func dropMessagesOverCount (maxTokens , toolTokenCount int , msgs []openai.ChatCompletionMessage ) (result []openai.ChatCompletionMessage , err error ) {
22
31
var (
23
32
lastSystem int
24
33
withinBudget int
25
- budget = getBudget (maxTokens )
34
+ budget = getBudget (maxTokens ) - toolTokenCount
26
35
)
27
36
28
37
for i , msg := range msgs {
29
38
if msg .Role == openai .ChatMessageRoleSystem {
30
- budget -= countMessage (msg )
39
+ count , err := countMessage (msg )
40
+ if err != nil {
41
+ return nil , err
42
+ }
43
+ budget -= count
31
44
lastSystem = i
32
45
result = append (result , msg )
33
46
} else {
@@ -37,7 +50,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
37
50
38
51
for i := len (msgs ) - 1 ; i > lastSystem ; i -- {
39
52
withinBudget = i
40
- budget -= countMessage (msgs [i ])
53
+ count , err := countMessage (msgs [i ])
54
+ if err != nil {
55
+ return nil , err
56
+ }
57
+ budget -= count
41
58
if budget <= 0 {
42
59
break
43
60
}
@@ -54,22 +71,42 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
54
71
if withinBudget == len (msgs )- 1 {
55
72
// We are going to drop all non system messages, which seems useless, so just return them
56
73
// all and let it fail
57
- return msgs
74
+ return msgs , nil
58
75
}
59
76
60
- return append (result , msgs [withinBudget :]... )
77
+ return append (result , msgs [withinBudget :]... ), nil
61
78
}
62
79
63
- func countMessage (msg openai.ChatCompletionMessage ) (count int ) {
64
- count += len (msg .Role )
65
- count += len (msg .Content )
80
+ func countMessage (msg openai.ChatCompletionMessage ) (int , error ) {
81
+ encoding , err := tiktoken .GetEncoding ("o200k_base" )
82
+ if err != nil {
83
+ return 0 , err
84
+ }
85
+
86
+ count := len (encoding .Encode (msg .Role , nil , nil ))
87
+ count += len (encoding .Encode (msg .Content , nil , nil ))
66
88
for _ , content := range msg .MultiContent {
67
- count += len (content .Text )
89
+ count += len (encoding . Encode ( content .Text , nil , nil ) )
68
90
}
69
91
for _ , tool := range msg .ToolCalls {
70
- count += len (tool .Function .Name )
71
- count += len (tool .Function .Arguments )
92
+ count += len (encoding . Encode ( tool .Function .Name , nil , nil ) )
93
+ count += len (encoding . Encode ( tool .Function .Arguments , nil , nil ) )
72
94
}
73
- count += len (msg .ToolCallID )
74
- return count / 3
95
+ count += len (encoding .Encode (msg .ToolCallID , nil , nil ))
96
+
97
+ return count , nil
98
+ }
99
+
100
+ func countTools (tools []types.ChatCompletionTool ) (int , error ) {
101
+ encoding , err := tiktoken .GetEncoding ("o200k_base" )
102
+ if err != nil {
103
+ return 0 , err
104
+ }
105
+
106
+ toolJSON , err := json .Marshal (tools )
107
+ if err != nil {
108
+ return 0 , err
109
+ }
110
+
111
+ return len (encoding .Encode (string (toolJSON ), nil , nil )), nil
75
112
}
0 commit comments