Fix Bedrock token count and IDs for Anthropic models #341

Merged · 10 commits · Feb 7, 2025
26 changes: 18 additions & 8 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,5 +1,6 @@
 import logging
 import re
+import warnings
 from collections import defaultdict
 from operator import itemgetter
 from typing import (
@@ -51,6 +52,7 @@
     _combine_generation_info_for_llm_result,
 )
 from langchain_aws.utils import (
+    anthropic_tokens_supported,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
 )
@@ -620,16 +622,24 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return final_output

     def get_num_tokens(self, text: str) -> int:
-        if self._model_is_anthropic:
-            return get_num_tokens_anthropic(text)
-        else:
-            return super().get_num_tokens(text)
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            if anthropic_tokens_supported():
+                return get_num_tokens_anthropic(text)
+        return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
-        if self._model_is_anthropic:
-            return get_token_ids_anthropic(text)
-        else:
-            return super().get_token_ids(text)
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            if anthropic_tokens_supported():
+                return get_token_ids_anthropic(text)
+            else:
+                warnings.warn(
+                    f"Falling back to default token method due to missing or incompatible `anthropic` installation "
+                    f"(needs <=0.38.0).\n\nIf using `anthropic>0.38.0`, it is recommended to provide the model "
+                    f"class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
+                    f"For get_num_tokens, as another alternative, you can implement your own token counter method "
+                    f"using the ChatAnthropic or AnthropicLLM classes."
+                )
+        return super().get_token_ids(text)

     def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
         """Workaround to bind. Sets the system prompt with tools"""
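To make the custom_get_token_ids escape hatch concrete, here is a minimal usage sketch (not part of this PR). custom_get_token_ids is an existing field on LangChain language models; the model ID and the placeholder tokenizer below are illustrative only, and constructing ChatBedrock assumes AWS credentials and a region are configured.

from typing import List

from langchain_aws import ChatBedrock


def my_token_ids(text: str) -> List[int]:
    # Placeholder tokenizer (whitespace split); swap in a real Claude-compatible one.
    return list(range(len(text.split())))


chat = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # example model ID
    custom_get_token_ids=my_token_ids,
)

# Because custom_get_token_ids is set, get_token_ids()/get_num_tokens() skip the
# anthropic-package path entirely, so no fallback warning is emitted.
print(chat.get_num_tokens("Hello, Bedrock!"))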
25 changes: 17 additions & 8 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -33,6 +33,7 @@

 from langchain_aws.function_calling import _tools_in_params
 from langchain_aws.utils import (
+    anthropic_tokens_supported,
     enforce_stop_tokens,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -1301,13 +1302,21 @@ async def _acall(
         return "".join([chunk.text for chunk in chunks])

     def get_num_tokens(self, text: str) -> int:
-        if self._model_is_anthropic:
-            return get_num_tokens_anthropic(text)
-        else:
-            return super().get_num_tokens(text)
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            if anthropic_tokens_supported():
+                return get_num_tokens_anthropic(text)
+        return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
-        if self._model_is_anthropic:
-            return get_token_ids_anthropic(text)
-        else:
-            return super().get_token_ids(text)
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            if anthropic_tokens_supported():
+                return get_token_ids_anthropic(text)
+            else:
+                warnings.warn(
+                    f"Falling back to default token method due to missing or incompatible `anthropic` installation "
+                    f"(needs <=0.38.0).\n\nFor `anthropic>0.38.0`, it is recommended to provide the model "
+                    f"class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
+                    f"For get_num_tokens, as another alternative, you can implement your own token counter method "
+                    f"using the ChatAnthropic or AnthropicLLM classes."
+                )
+        return super().get_token_ids(text)
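The completion-style BedrockLLM class gets the same guard. A hedged sketch of what a caller observes at runtime (the model ID is an example, and constructing the model assumes AWS credentials and a region are configured):

from langchain_aws import BedrockLLM

llm = BedrockLLM(model_id="anthropic.claude-v2")  # example model ID

# With anthropic<=0.38.0 and httpx<=0.27.2 installed, this uses Anthropic's own
# tokenizer via get_num_tokens_anthropic(); otherwise a warning is emitted and the
# base class's approximate counter is used instead.
print(llm.get_num_tokens("How many tokens is this sentence?"))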
30 changes: 24 additions & 6 deletions libs/aws/langchain_aws/utils.py
@@ -1,21 +1,39 @@
 import re
 from typing import Any, List

+from packaging import version
+

 def enforce_stop_tokens(text: str, stop: List[str]) -> str:
     """Cut off the text as soon as any stop words occur."""
     return re.split("|".join(stop), text, maxsplit=1)[0]


-def _get_anthropic_client() -> Any:
+def anthropic_tokens_supported() -> bool:
+    """Check if we have all requirements for Anthropic count_tokens() and get_tokenizer()."""
     try:
         import anthropic
     except ImportError:
-        raise ImportError(
-            "Could not import anthropic python package. "
-            "This is needed in order to accurately tokenize the text "
-            "for anthropic models. Please install it with `pip install anthropic`."
-        )
+        return False
+
+    anthropic_version = version.parse(anthropic.__version__)
+    if anthropic_version > version.parse("0.38.0"):
+        return False
+    else:
+        httpx_import_msg = "httpx<=0.27.2 is required."
+        try:
+            import httpx
+        except ImportError:
+            raise ImportError(httpx_import_msg)
+        httpx_version = version.parse(httpx.__version__)
+        if httpx_version > version.parse("0.27.2"):
+            raise ImportError(httpx_import_msg)
+        else:
+            return True
+
+
+def _get_anthropic_client() -> Any:
+    import anthropic
     return anthropic.Anthropic()


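Finally, a small sketch (not part of the diff) of how the new helper composes with the existing tokenizer utilities: it only verifies the import and version constraints, while the actual counting still goes through get_num_tokens_anthropic().

from langchain_aws.utils import (
    anthropic_tokens_supported,
    get_num_tokens_anthropic,
)

text = "Sample prompt for token counting."
if anthropic_tokens_supported():
    # Only reachable with anthropic<=0.38.0 and httpx<=0.27.2 installed.
    print(get_num_tokens_anthropic(text))
else:
    # Exact Anthropic tokenizer unavailable; callers fall back to an approximation.
    print("anthropic tokenizer not available")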