Skip to content

Commit

Permalink
Fix mistral count tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
bkiat1123 committed Jan 18, 2024
1 parent febe594 commit 5341814
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions llms/providers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from mistralai.models.chat_completion import ChatMessage



from ..results.result import AsyncStreamResult, Result, StreamResult
from .base_provider import BaseProvider

Expand Down Expand Up @@ -37,26 +36,19 @@ def __init__(
async_client_kwargs = {}
self.async_client = MistralAsyncClient(api_key=api_key, **async_client_kwargs)

def count_tokens(self, content: Union[str, List[dict]]) -> int:
def count_tokens(self, content: str | List[ChatMessage]) -> int:
        # TODO: update after Mistral supports token counting in their SDK
# use gpt 3.5 turbo for estimation now
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
if isinstance(content, list):
# When field name is present, ChatGPT will ignore the role token.
# Adopted from OpenAI cookbook
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
# every message follows <im_start>{role/name}\n{content}<im_end>\n
formatting_token_count = 4

messages = content
messages_text = ["".join(message.values()) for message in messages]
messages_text = [f"{message.role}{message.content}" for message in messages]
tokens = [enc.encode(t, disallowed_special=()) for t in messages_text]

n_tokens_list = []
for token, message in zip(tokens, messages):
n_tokens = len(token) + formatting_token_count
if "name" in message:
n_tokens += -1
n_tokens_list.append(n_tokens)
return sum(n_tokens_list)
else:
Expand Down

0 comments on commit 5341814

Please sign in to comment.