From fd36b2c1d722f393c3f320b5ca91f1e56c6326fd Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Wed, 9 Apr 2025 16:39:09 -0400
Subject: [PATCH 1/2] fix: estimate tokens using tiktoken

Signed-off-by: Grant Linville
---
 go.mod               |  4 ++-
 go.sum               |  8 ++++--
 pkg/openai/client.go | 27 ++++++++++++++----
 pkg/openai/count.go  | 65 ++++++++++++++++++++++++++++++++++----------
 4 files changed, 81 insertions(+), 23 deletions(-)

diff --git a/go.mod b/go.mod
index 15f88d5f..f803a3b9 100644
--- a/go.mod
+++ b/go.mod
@@ -23,6 +23,8 @@ require (
 	github.com/hexops/valast v1.4.4
 	github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
 	github.com/mholt/archives v0.1.0
+	github.com/pkoukk/tiktoken-go v0.1.7
+	github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699
 	github.com/rs/cors v1.11.0
 	github.com/samber/lo v1.38.1
 	github.com/sirupsen/logrus v1.9.3
@@ -62,7 +64,7 @@ require (
 	github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
 	github.com/cyphar/filepath-securejoin v0.2.4 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
-	github.com/dlclark/regexp2 v1.4.0 // indirect
+	github.com/dlclark/regexp2 v1.10.0 // indirect
 	github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect
 	github.com/emirpasic/gods v1.18.1 // indirect
 	github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
diff --git a/go.sum b/go.sum
index 07d8d500..74341af5 100644
--- a/go.sum
+++ b/go.sum
@@ -108,8 +108,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E=
-github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
+github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
+github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I=
 github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
 github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo=
@@ -316,6 +316,10 @@ github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFz
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
+github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
+github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 h1:Sp8yiuxsitkmCfEvUnmNf8wzuZwlGNkRjI2yF0C3QUQ=
+github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index 65bc2ae8..7715c657 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -349,16 +349,29 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	}
 
+	toolTokenCount, err := countTools(messageRequest.Tools)
+	if err != nil {
+		return nil, err
+	}
+
 	if messageRequest.Chat {
 		// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
 		lastMessage := msgs[len(msgs)-1]
-		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
+		lastMessageCount, err := countMessage(lastMessage)
+		if err != nil {
+			return nil, err
+		}
+
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && lastMessageCount+toolTokenCount > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
 			// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
 			msgs[len(msgs)-1].Content = TooLongMessage
 			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
 		}
 
-		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
+		msgs, err = dropMessagesOverCount(messageRequest.MaxTokens, toolTokenCount, msgs)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	if len(msgs) == 0 {
@@ -439,7 +452,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		// Decrease maxTokens by 10% to make garbage collection more aggressive.
 		// The retry loop will further decrease maxTokens if needed.
 		maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-		result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
+		result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, toolTokenCount, status)
 	}
 	if err != nil {
 		return nil, err
@@ -473,7 +486,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
-func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, toolTokenCount int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	var (
 		response types.CompletionMessage
 		err      error
@@ -481,7 +494,11 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
 
 	for range 10 { // maximum 10 tries
 		// Try to drop older messages again, with a decreased max tokens.
-		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		request.Messages, err = dropMessagesOverCount(maxTokens, toolTokenCount, request.Messages)
+		if err != nil {
+			return types.CompletionMessage{}, err
+		}
+
 		response, err = c.call(ctx, request, id, env, status)
 		if err == nil {
 			return response, nil
diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index ffd902e5..8e755d1c 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -1,7 +1,12 @@
 package openai
 
 import (
+	"encoding/json"
+
 	openai "github.com/gptscript-ai/chat-completion-client"
+	"github.com/gptscript-ai/gptscript/pkg/types"
+	"github.com/pkoukk/tiktoken-go"
+	tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
 )
 
 const DefaultMaxTokens = 128_000
@@ -12,22 +17,26 @@ func decreaseTenPercent(maxTokens int) int {
 }
 
 func getBudget(maxTokens int) int {
-	if maxTokens == 0 {
+	if maxTokens <= 0 {
 		return DefaultMaxTokens
 	}
 	return maxTokens
 }
 
-func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
+func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage, err error) {
 	var (
 		lastSystem   int
 		withinBudget int
-		budget       = getBudget(maxTokens)
+		budget       = getBudget(maxTokens) - toolTokenCount
 	)
 
 	for i, msg := range msgs {
 		if msg.Role == openai.ChatMessageRoleSystem {
-			budget -= countMessage(msg)
+			count, err := countMessage(msg)
+			if err != nil {
+				return nil, err
+			}
+			budget -= count
 			lastSystem = i
 			result = append(result, msg)
 		} else {
@@ -37,7 +46,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 
 	for i := len(msgs) - 1; i > lastSystem; i-- {
 		withinBudget = i
-		budget -= countMessage(msgs[i])
+		count, err := countMessage(msgs[i])
+		if err != nil {
+			return nil, err
+		}
+		budget -= count
 		if budget <= 0 {
 			break
 		}
@@ -54,22 +67,44 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 	if withinBudget == len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
-		return msgs
+		return msgs, nil
 	}
 
-	return append(result, msgs[withinBudget:]...)
+	return append(result, msgs[withinBudget:]...), nil
 }
 
-func countMessage(msg openai.ChatCompletionMessage) (count int) {
-	count += len(msg.Role)
-	count += len(msg.Content)
+func countMessage(msg openai.ChatCompletionMessage) (int, error) {
+	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
+	encoding, err := tiktoken.GetEncoding("o200k_base")
+	if err != nil {
+		return 0, err
+	}
+
+	count := len(encoding.Encode(msg.Role, nil, nil))
+	count += len(encoding.Encode(msg.Content, nil, nil))
 	for _, content := range msg.MultiContent {
-		count += len(content.Text)
+		count += len(encoding.Encode(content.Text, nil, nil))
 	}
 	for _, tool := range msg.ToolCalls {
-		count += len(tool.Function.Name)
-		count += len(tool.Function.Arguments)
+		count += len(encoding.Encode(tool.Function.Name, nil, nil))
+		count += len(encoding.Encode(tool.Function.Arguments, nil, nil))
 	}
-	count += len(msg.ToolCallID)
-	return count / 3
+	count += len(encoding.Encode(msg.ToolCallID, nil, nil))
+
+	return count, nil
+}
+
+func countTools(tools []types.ChatCompletionTool) (int, error) {
+	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
+	encoding, err := tiktoken.GetEncoding("o200k_base")
+	if err != nil {
+		return 0, err
+	}
+
+	toolJSON, err := json.Marshal(tools)
+	if err != nil {
+		return 0, err
+	}
+
+	return len(encoding.Encode(string(toolJSON), nil, nil)), nil
+}

From 7ea03af63ff691f21fa32a4e21f9c539c1c511a0 Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Thu, 10 Apr 2025 10:23:37 -0400
Subject: [PATCH 2/2] PR feedback

Signed-off-by: Grant Linville
---
 pkg/openai/count.go | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index 8e755d1c..d8f2ca36 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -9,6 +9,10 @@ import (
 	tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
 )
 
+func init() {
+	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
+}
+
 const DefaultMaxTokens = 128_000
 
 func decreaseTenPercent(maxTokens int) int {
@@ -74,7 +78,6 @@ func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatComp
 }
 
 func countMessage(msg openai.ChatCompletionMessage) (int, error) {
-	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
 	encoding, err := tiktoken.GetEncoding("o200k_base")
 	if err != nil {
 		return 0, err
@@ -95,7 +98,6 @@ func countMessage(msg openai.ChatCompletionMessage) (int, error) {
 }
 
 func countTools(tools []types.ChatCompletionTool) (int, error) {
-	tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
 	encoding, err := tiktoken.GetEncoding("o200k_base")
 	if err != nil {
 		return 0, err
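
Note: the token counting these patches introduce (replacing the old len/3
byte heuristic) reduces to encoding text with tiktoken's o200k_base encoding
and taking the length of the resulting token slice. A minimal standalone
sketch of that approach, using only the APIs the patches themselves rely on;
the sample program and string are illustrative, not part of the change:

	package main

	import (
		"fmt"

		"github.com/pkoukk/tiktoken-go"
		tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
	)

	func main() {
		// Offline BPE loader: token ranks are bundled, so no network
		// fetch happens at runtime (same setup as the init() in PATCH 2/2).
		tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())

		encoding, err := tiktoken.GetEncoding("o200k_base")
		if err != nil {
			panic(err)
		}

		// Encode returns the token IDs; the token count is the slice length.
		tokens := encoding.Encode("Hello, world!", nil, nil)
		fmt.Println(len(tokens))
	}

Counting real tokens makes the budget math in dropMessagesOverCount accurate
instead of approximate, at the cost of loading the BPE ranks once, which is
why PATCH 2/2 hoists SetBpeLoader out of the per-call paths into init().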