Skip to content

fix: estimate tokens using tiktoken #959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ require (
github.com/hexops/valast v1.4.4
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
github.com/mholt/archives v0.1.0
github.com/pkoukk/tiktoken-go v0.1.7
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699
github.com/rs/cors v1.11.0
github.com/samber/lo v1.38.1
github.com/sirupsen/logrus v1.9.3
Expand Down Expand Up @@ -62,7 +64,7 @@ require (
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E=
github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I=
github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo=
Expand Down Expand Up @@ -316,6 +316,10 @@ github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFz
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 h1:Sp8yiuxsitkmCfEvUnmNf8wzuZwlGNkRjI2yF0C3QUQ=
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
Expand Down
27 changes: 22 additions & 5 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,29 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return nil, err
}

toolTokenCount, err := countTools(messageRequest.Tools)
if err != nil {
return nil, err
}

if messageRequest.Chat {
// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
lastMessage := msgs[len(msgs)-1]
if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
lastMessageCount, err := countMessage(lastMessage)
if err != nil {
return nil, err
}

if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && lastMessageCount+toolTokenCount > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
msgs[len(msgs)-1].Content = TooLongMessage
messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
}

msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
msgs, err = dropMessagesOverCount(messageRequest.MaxTokens, toolTokenCount, msgs)
if err != nil {
return nil, err
}
}

if len(msgs) == 0 {
Expand Down Expand Up @@ -439,7 +452,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, toolTokenCount, status)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -473,15 +486,19 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return &result, nil
}

func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, toolTokenCount int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
var (
response types.CompletionMessage
err error
)

for range 10 { // maximum 10 tries
// Try to drop older messages again, with a decreased max tokens.
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
request.Messages, err = dropMessagesOverCount(maxTokens, toolTokenCount, request.Messages)
if err != nil {
return types.CompletionMessage{}, err
}

response, err = c.call(ctx, request, id, env, status)
if err == nil {
return response, nil
Expand Down
67 changes: 52 additions & 15 deletions pkg/openai/count.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package openai

import (
	"encoding/json"
	"sync"

	openai "github.com/gptscript-ai/chat-completion-client"
	"github.com/gptscript-ai/gptscript/pkg/types"
	"github.com/pkoukk/tiktoken-go"
	tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
)

// init registers the offline BPE loader with tiktoken so encodings are
// resolved from bundled data instead of being downloaded at runtime.
// NOTE(review): this assumes the loader package embeds the o200k_base
// vocabulary used below — confirm against tiktoken-go-loader.
func init() {
	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
}

const DefaultMaxTokens = 128_000

func decreaseTenPercent(maxTokens int) int {
Expand All @@ -12,22 +21,26 @@ func decreaseTenPercent(maxTokens int) int {
}

func getBudget(maxTokens int) int {
if maxTokens == 0 {
if maxTokens <= 0 {
return DefaultMaxTokens
}
return maxTokens
}

func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage, err error) {
var (
lastSystem int
withinBudget int
budget = getBudget(maxTokens)
budget = getBudget(maxTokens) - toolTokenCount
)

for i, msg := range msgs {
if msg.Role == openai.ChatMessageRoleSystem {
budget -= countMessage(msg)
count, err := countMessage(msg)
if err != nil {
return nil, err
}
budget -= count
lastSystem = i
result = append(result, msg)
} else {
Expand All @@ -37,7 +50,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (

for i := len(msgs) - 1; i > lastSystem; i-- {
withinBudget = i
budget -= countMessage(msgs[i])
count, err := countMessage(msgs[i])
if err != nil {
return nil, err
}
budget -= count
if budget <= 0 {
break
}
Expand All @@ -54,22 +71,42 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
if withinBudget == len(msgs)-1 {
// We are going to drop all non system messages, which seems useless, so just return them
// all and let it fail
return msgs
return msgs, nil
}

return append(result, msgs[withinBudget:]...)
return append(result, msgs[withinBudget:]...), nil
}

func countMessage(msg openai.ChatCompletionMessage) (count int) {
count += len(msg.Role)
count += len(msg.Content)
func countMessage(msg openai.ChatCompletionMessage) (int, error) {
encoding, err := tiktoken.GetEncoding("o200k_base")
if err != nil {
return 0, err
}

count := len(encoding.Encode(msg.Role, nil, nil))
count += len(encoding.Encode(msg.Content, nil, nil))
for _, content := range msg.MultiContent {
count += len(content.Text)
count += len(encoding.Encode(content.Text, nil, nil))
}
for _, tool := range msg.ToolCalls {
count += len(tool.Function.Name)
count += len(tool.Function.Arguments)
count += len(encoding.Encode(tool.Function.Name, nil, nil))
count += len(encoding.Encode(tool.Function.Arguments, nil, nil))
}
count += len(msg.ToolCallID)
return count / 3
count += len(encoding.Encode(msg.ToolCallID, nil, nil))

return count, nil
}

// countTools estimates the number of o200k_base tokens consumed by the
// tool definitions sent with a chat request, by tokenizing their JSON
// serialization. It returns an error if marshalling fails or the
// encoding cannot be loaded.
func countTools(tools []types.ChatCompletionTool) (int, error) {
	serialized, err := json.Marshal(tools)
	if err != nil {
		return 0, err
	}

	encoding, err := tiktoken.GetEncoding("o200k_base")
	if err != nil {
		return 0, err
	}

	return len(encoding.Encode(string(serialized), nil, nil)), nil
}