Skip to content

Commit c519c63

Browse files
authored
fix: estimate tokens using tiktoken (#959)
Signed-off-by: Grant Linville <[email protected]>
1 parent 9abfd87 commit c519c63

File tree

4 files changed

+83
-23
lines changed

4 files changed

+83
-23
lines changed

go.mod

+3 -1
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@ require (
2323
github.com/hexops/valast v1.4.4
2424
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
2525
github.com/mholt/archives v0.1.0
26+
github.com/pkoukk/tiktoken-go v0.1.7
27+
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699
2628
github.com/rs/cors v1.11.0
2729
github.com/samber/lo v1.38.1
2830
github.com/sirupsen/logrus v1.9.3
@@ -62,7 +64,7 @@ require (
6264
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
6365
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
6466
github.com/davecgh/go-spew v1.1.1 // indirect
65-
github.com/dlclark/regexp2 v1.4.0 // indirect
67+
github.com/dlclark/regexp2 v1.10.0 // indirect
6668
github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect
6769
github.com/emirpasic/gods v1.18.1 // indirect
6870
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect

go.sum

+6 -2
Original file line number | Diff line number | Diff line change
@@ -108,8 +108,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG
108108
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
109109
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
110110
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
111-
github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E=
112-
github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
111+
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
112+
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
113113
github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I=
114114
github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
115115
github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo=
@@ -316,6 +316,10 @@ github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFz
316316
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
317317
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
318318
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
319+
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
320+
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
321+
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 h1:Sp8yiuxsitkmCfEvUnmNf8wzuZwlGNkRjI2yF0C3QUQ=
322+
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
319323
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
320324
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
321325
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=

pkg/openai/client.go

+22 -5
Original file line number | Diff line number | Diff line change
@@ -349,16 +349,29 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
349349
return nil, err
350350
}
351351

352+
toolTokenCount, err := countTools(messageRequest.Tools)
353+
if err != nil {
354+
return nil, err
355+
}
356+
352357
if messageRequest.Chat {
353358
// Check the last message. If it is from a tool call, and if it, together with the tool definitions, takes up more than 80% of the budget, reject it.
354359
lastMessage := msgs[len(msgs)-1]
355-
if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
360+
lastMessageCount, err := countMessage(lastMessage)
361+
if err != nil {
362+
return nil, err
363+
}
364+
365+
if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && lastMessageCount+toolTokenCount > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
356366
// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
357367
msgs[len(msgs)-1].Content = TooLongMessage
358368
messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
359369
}
360370

361-
msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
371+
msgs, err = dropMessagesOverCount(messageRequest.MaxTokens, toolTokenCount, msgs)
372+
if err != nil {
373+
return nil, err
374+
}
362375
}
363376

364377
if len(msgs) == 0 {
@@ -439,7 +452,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
439452
// Decrease maxTokens by 10% to make garbage collection more aggressive.
440453
// The retry loop will further decrease maxTokens if needed.
441454
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
442-
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
455+
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, toolTokenCount, status)
443456
}
444457
if err != nil {
445458
return nil, err
@@ -473,15 +486,19 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
473486
return &result, nil
474487
}
475488

476-
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
489+
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, toolTokenCount int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
477490
var (
478491
response types.CompletionMessage
479492
err error
480493
)
481494

482495
for range 10 { // maximum 10 tries
483496
// Try to drop older messages again, with a decreased max tokens.
484-
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
497+
request.Messages, err = dropMessagesOverCount(maxTokens, toolTokenCount, request.Messages)
498+
if err != nil {
499+
return types.CompletionMessage{}, err
500+
}
501+
485502
response, err = c.call(ctx, request, id, env, status)
486503
if err == nil {
487504
return response, nil

pkg/openai/count.go

+52 -15
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,18 @@
11
package openai
22

33
import (
4+
"encoding/json"
5+
46
openai "github.com/gptscript-ai/chat-completion-client"
7+
"github.com/gptscript-ai/gptscript/pkg/types"
8+
"github.com/pkoukk/tiktoken-go"
9+
tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
510
)
611

12+
func init() {
13+
tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
14+
}
15+
716
const DefaultMaxTokens = 128_000
817

918
func decreaseTenPercent(maxTokens int) int {
@@ -12,22 +21,26 @@ func decreaseTenPercent(maxTokens int) int {
1221
}
1322

1423
func getBudget(maxTokens int) int {
15-
if maxTokens == 0 {
24+
if maxTokens <= 0 {
1625
return DefaultMaxTokens
1726
}
1827
return maxTokens
1928
}
2029

21-
func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
30+
func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage, err error) {
2231
var (
2332
lastSystem int
2433
withinBudget int
25-
budget = getBudget(maxTokens)
34+
budget = getBudget(maxTokens) - toolTokenCount
2635
)
2736

2837
for i, msg := range msgs {
2938
if msg.Role == openai.ChatMessageRoleSystem {
30-
budget -= countMessage(msg)
39+
count, err := countMessage(msg)
40+
if err != nil {
41+
return nil, err
42+
}
43+
budget -= count
3144
lastSystem = i
3245
result = append(result, msg)
3346
} else {
@@ -37,7 +50,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
3750

3851
for i := len(msgs) - 1; i > lastSystem; i-- {
3952
withinBudget = i
40-
budget -= countMessage(msgs[i])
53+
count, err := countMessage(msgs[i])
54+
if err != nil {
55+
return nil, err
56+
}
57+
budget -= count
4158
if budget <= 0 {
4259
break
4360
}
@@ -54,22 +71,42 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
5471
if withinBudget == len(msgs)-1 {
5572
// We are going to drop all non system messages, which seems useless, so just return them
5673
// all and let it fail
57-
return msgs
74+
return msgs, nil
5875
}
5976

60-
return append(result, msgs[withinBudget:]...)
77+
return append(result, msgs[withinBudget:]...), nil
6178
}
6279

63-
func countMessage(msg openai.ChatCompletionMessage) (count int) {
64-
count += len(msg.Role)
65-
count += len(msg.Content)
80+
func countMessage(msg openai.ChatCompletionMessage) (int, error) {
81+
encoding, err := tiktoken.GetEncoding("o200k_base")
82+
if err != nil {
83+
return 0, err
84+
}
85+
86+
count := len(encoding.Encode(msg.Role, nil, nil))
87+
count += len(encoding.Encode(msg.Content, nil, nil))
6688
for _, content := range msg.MultiContent {
67-
count += len(content.Text)
89+
count += len(encoding.Encode(content.Text, nil, nil))
6890
}
6991
for _, tool := range msg.ToolCalls {
70-
count += len(tool.Function.Name)
71-
count += len(tool.Function.Arguments)
92+
count += len(encoding.Encode(tool.Function.Name, nil, nil))
93+
count += len(encoding.Encode(tool.Function.Arguments, nil, nil))
7294
}
73-
count += len(msg.ToolCallID)
74-
return count / 3
95+
count += len(encoding.Encode(msg.ToolCallID, nil, nil))
96+
97+
return count, nil
98+
}
99+
100+
func countTools(tools []types.ChatCompletionTool) (int, error) {
101+
encoding, err := tiktoken.GetEncoding("o200k_base")
102+
if err != nil {
103+
return 0, err
104+
}
105+
106+
toolJSON, err := json.Marshal(tools)
107+
if err != nil {
108+
return 0, err
109+
}
110+
111+
return len(encoding.Encode(string(toolJSON), nil, nil)), nil
75112
}

0 commit comments

Comments
 (0)