From b242c66a3b4b351ea906d8d898dbc4fef20923c6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Feb 2025 22:41:11 -0800 Subject: [PATCH] (Feat) - Add `/bedrock/invoke` support for all Anthropic models (#8383) * use anthropic transformation for bedrock/invoke * use anthropic transforms for bedrock invoke claude * TestBedrockInvokeClaudeJson * add AmazonAnthropicClaudeStreamDecoder * pass bedrock_invoke_provider to make_call * fix _get_base_bedrock_model * fix get_bedrock_route * fix bedrock routing * fixes for bedrock invoke * test_all_model_configs * fix AWSEventStreamDecoder linting * fix code qa * test_bedrock_get_base_model * test_get_model_info_bedrock_models * test_bedrock_base_model_helper * test_bedrock_route_detection --- litellm/llms/base_llm/base_utils.py | 11 ++ .../bedrock/chat/converse_transformation.py | 74 +++------ litellm/llms/bedrock/chat/invoke_handler.py | 42 ++++- .../anthropic_claude3_transformation.py | 125 +++++++++------ .../base_invoke_transformation.py | 151 ++++-------------- litellm/llms/bedrock/common_utils.py | 68 +++++++- .../llms/openai/chat/gpt_transformation.py | 4 + litellm/llms/topaz/common_utils.py | 4 + litellm/main.py | 10 +- litellm/utils.py | 65 +++++--- .../test_bedrock_completion.py | 48 +++++- .../test_bedrock_invoke_claude_json.py | 28 ++++ .../test_max_completion_tokens.py | 6 +- tests/local_testing/test_get_llm_provider.py | 8 + tests/local_testing/test_get_model_info.py | 4 +- 15 files changed, 386 insertions(+), 262 deletions(-) create mode 100644 tests/llm_translation/test_bedrock_invoke_claude_json.py diff --git a/litellm/llms/base_llm/base_utils.py b/litellm/llms/base_llm/base_utils.py index ac3d2c81f9f7..a7e65cdfbfbc 100644 --- a/litellm/llms/base_llm/base_utils.py +++ b/litellm/llms/base_llm/base_utils.py @@ -34,6 +34,17 @@ def get_api_key(api_key: Optional[str] = None) -> Optional[str]: def get_api_base(api_base: Optional[str] = None) -> Optional[str]: pass + @staticmethod + @abstractmethod + def get_base_model(model: str) -> Optional[str]: + """ + Returns the base model name from the given model name. 
+ + Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0` + This function will return `anthropic.claude-3-opus-20240229-v1:0` + """ + pass + def _dict_to_response_format_helper( response_format: dict, ref_template: Optional[str] = None diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index aa09fb30611b..548e6f690a55 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -33,14 +33,7 @@ from litellm.types.utils import ModelResponse, Usage from litellm.utils import add_dummy_tool, has_tool_call_blocks -from ..common_utils import ( - AmazonBedrockGlobalConfig, - BedrockError, - get_bedrock_tool_name, -) - -global_config = AmazonBedrockGlobalConfig() -all_global_regions = global_config.get_all_regions() +from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name class AmazonConverseConfig(BaseConfig): @@ -104,7 +97,7 @@ def get_supported_openai_params(self, model: str) -> List[str]: ] ## Filter out 'cross-region' from model name - base_model = self._get_base_model(model) + base_model = BedrockModelInfo.get_base_model(model) if ( base_model.startswith("anthropic") @@ -341,9 +334,9 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig if "top_k" in inference_params: inference_params["topK"] = inference_params.pop("top_k") return InferenceConfig(**inference_params) - + def _handle_top_k_value(self, model: str, inference_params: dict) -> dict: - base_model = self._get_base_model(model) + base_model = BedrockModelInfo.get_base_model(model) val_top_k = None if "topK" in inference_params: @@ -352,11 +345,11 @@ def _handle_top_k_value(self, model: str, inference_params: dict) -> dict: val_top_k = inference_params.pop("top_k") if val_top_k: - if (base_model.startswith("anthropic")): + if base_model.startswith("anthropic"): return {"top_k": val_top_k} if base_model.startswith("amazon.nova"): - return {'inferenceConfig': {"topK": val_top_k}} - + return {"inferenceConfig": {"topK": val_top_k}} + return {} def _transform_request_helper( @@ -393,15 +386,25 @@ def _transform_request_helper( ) + ["top_k"] supported_tool_call_params = ["tools", "tool_choice"] supported_guardrail_params = ["guardrailConfig"] - total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params + total_supported_params = ( + supported_converse_params + + supported_tool_call_params + + supported_guardrail_params + ) inference_params.pop("json_mode", None) # used for handling json_schema # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params' - additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params} - inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params} + additional_request_params = { + k: v for k, v in inference_params.items() if k not in total_supported_params + } + inference_params = { + k: v for k, v in inference_params.items() if k in total_supported_params + } # Only set the topK value in for models that support it - additional_request_params.update(self._handle_top_k_value(model, inference_params)) + additional_request_params.update( + self._handle_top_k_value(model, inference_params) + ) bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( 
inference_params.pop("tools", []) @@ -679,41 +682,6 @@ def _transform_response( return model_response - def _supported_cross_region_inference_region(self) -> List[str]: - """ - Abbreviations of regions AWS Bedrock supports for cross region inference - """ - return ["us", "eu", "apac"] - - def _get_base_model(self, model: str) -> str: - """ - Get the base model from the given model name. - - Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" - AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" - """ - - if model.startswith("bedrock/"): - model = model.split("/", 1)[1] - - if model.startswith("converse/"): - model = model.split("/", 1)[1] - - potential_region = model.split(".", 1)[0] - - alt_potential_region = model.split("/", 1)[ - 0 - ] # in model cost map we store regional information like `/us-west-2/bedrock-model` - - if potential_region in self._supported_cross_region_inference_region(): - return model.split(".", 1)[1] - elif ( - alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1 - ): - return model.split("/", 1)[1] - - return model - def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 460c2bbeaca5..db419aa1104d 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -40,6 +40,9 @@ parse_xml_params, prompt_factory, ) +from litellm.llms.anthropic.chat.handler import ( + ModelResponseIterator as AnthropicModelResponseIterator, +) from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -177,6 +180,7 @@ async def make_call( logging_obj: Logging, fake_stream: bool = False, json_mode: Optional[bool] = False, + bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None, ): try: if client is None: @@ -214,6 +218,14 @@ async def make_call( completion_stream: Any = MockResponseIterator( model_response=model_response, json_mode=json_mode ) + elif bedrock_invoke_provider == "anthropic": + decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder( + model=model, + sync_stream=False, + ) + completion_stream = decoder.aiter_bytes( + response.aiter_bytes(chunk_size=1024) + ) else: decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.aiter_bytes( @@ -248,6 +260,7 @@ def make_sync_call( logging_obj: Logging, fake_stream: bool = False, json_mode: Optional[bool] = False, + bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None, ): try: if client is None: @@ -283,6 +296,12 @@ def make_sync_call( completion_stream: Any = MockResponseIterator( model_response=model_response, json_mode=json_mode ) + elif bedrock_invoke_provider == "anthropic": + decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder( + model=model, + sync_stream=True, + ) + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) else: decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) @@ -1323,7 +1342,7 @@ def _chunk_parser(self, chunk_data: dict) -> GChunk: text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore is_finished = True finish_reason = "stop" - ######## bedrock.anthropic mappings ############### + ######## converse bedrock.anthropic mappings ############### elif ( 
"contentBlockIndex" in chunk_data or "stopReason" in chunk_data @@ -1429,6 +1448,27 @@ def _parse_message_from_event(self, event) -> Optional[str]: return chunk.decode() # type: ignore[no-any-return] +class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder): + def __init__( + self, + model: str, + sync_stream: bool, + ) -> None: + """ + Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models + + The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method + """ + super().__init__(model=model) + self.anthropic_model_response_iterator = AnthropicModelResponseIterator( + streaming_response=None, + sync_stream=sync_stream, + ) + + def _chunk_parser(self, chunk_data: dict) -> GChunk: + return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data) + + class MockResponseIterator: # for returning ai21 streaming responses def __init__(self, model_response, json_mode: Optional[bool] = False): self.model_response = model_response diff --git a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py index ca8c0bf98127..09842aef0170 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py @@ -1,61 +1,34 @@ -import types -from typing import List, Optional +from typing import TYPE_CHECKING, Any, List, Optional +import httpx -class AmazonAnthropicClaude3Config: +import litellm +from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import ( + AmazonInvokeConfig, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class AmazonAnthropicClaude3Config(AmazonInvokeConfig): """ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude https://docs.anthropic.com/claude/docs/models-overview#model-comparison Supported Params for the Amazon / Anthropic Claude 3 models: - - - `max_tokens` Required (integer) max tokens. Default is 4096 - - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" - - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py - - `temperature` Optional (float) The amount of randomness injected into the response - - `top_p` Optional (float) Use nucleus sampling. 
- - `top_k` Optional (int) Only sample from the top K options for each subsequent token - - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating """ - max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default - anthropic_version: Optional[str] = "bedrock-2023-05-31" - system: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[int] = None - stop_sequences: Optional[List[str]] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - anthropic_version: Optional[str] = None, - ) -> None: - locals_ = locals().copy() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } + anthropic_version: str = "bedrock-2023-05-31" - def get_supported_openai_params(self): + def get_supported_openai_params(self, model: str): return [ "max_tokens", "max_completion_tokens", @@ -68,7 +41,13 @@ def get_supported_openai_params(self): "extra_headers", ] - def map_openai_params(self, non_default_params: dict, optional_params: dict): + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ): for param, value in non_default_params.items(): if param == "max_tokens" or param == "max_completion_tokens": optional_params["max_tokens"] = value @@ -83,3 +62,53 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict): if param == "top_p": optional_params["top_p"] = value return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + _anthropic_request = litellm.AnthropicConfig().transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, + ) + + _anthropic_request.pop("model", None) + if "anthropic_version" not in _anthropic_request: + _anthropic_request["anthropic_version"] = self.anthropic_version + + return _anthropic_request + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + return litellm.AnthropicConfig().transform_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + request_data=request_data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + api_key=api_key, + json_mode=json_mode, + ) diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index fbcd7660b2ed..6284a7ab088e 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -2,7 +2,6 @@ import json import time import urllib.parse -import uuid from functools 
import partial from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args @@ -13,11 +12,7 @@ from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.litellm_core_utils.prompt_templates.factory import ( cohere_message_pt, - construct_tool_use_system_prompt, - contains_tag, custom_prompt, - extract_between_tags, - parse_xml_params, prompt_factory, ) from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException @@ -194,7 +189,6 @@ def transform_request( # noqa: PLR0915 for k, v in inference_params.items() if k not in self.aws_authentication_params } - json_schemas: dict = {} request_data: dict = {} if provider == "cohere": if model.startswith("cohere.command-r"): @@ -223,57 +217,13 @@ def transform_request( # noqa: PLR0915 ) request_data = {"prompt": prompt, **inference_params} elif provider == "anthropic": - if model.startswith("anthropic.claude-3"): - # Separate system prompt from rest of message - system_prompt_idx: list[int] = [] - system_messages: list[str] = [] - for idx, message in enumerate(messages): - if message["role"] == "system" and isinstance( - message["content"], str - ): - system_messages.append(message["content"]) - system_prompt_idx.append(idx) - if len(system_prompt_idx) > 0: - inference_params["system"] = "\n".join(system_messages) - messages = [ - i for j, i in enumerate(messages) if j not in system_prompt_idx - ] - # Format rest of message according to anthropic guidelines - messages = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic_xml" - ) # type: ignore - ## LOAD CONFIG - config = litellm.AmazonAnthropicClaude3Config.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - ## Handle Tool Calling - if "tools" in inference_params: - _is_function_call = True - for tool in inference_params["tools"]: - json_schemas[tool["function"]["name"]] = tool["function"].get( - "parameters", None - ) - tool_calling_system_prompt = construct_tool_use_system_prompt( - tools=inference_params["tools"] - ) - inference_params["system"] = ( - inference_params.get("system", "\n") - + tool_calling_system_prompt - ) # add the anthropic tool calling prompt to the system prompt - inference_params.pop("tools") - request_data = {"messages": messages, **inference_params} - else: - ## LOAD CONFIG - config = litellm.AmazonAnthropicConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - request_data = {"prompt": prompt, **inference_params} + return litellm.AmazonAnthropicClaude3Config().transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, + ) elif provider == "ai21": ## LOAD CONFIG config = litellm.AmazonAI21Config.get_config() @@ -359,66 +309,19 @@ def transform_response( # noqa: PLR0915 completion_response["generations"][0]["finish_reason"] ) elif provider == "anthropic": - if model.startswith("anthropic.claude-3"): - json_schemas: dict = {} - _is_function_call = False - ## Handle Tool Calling - if "tools" in optional_params: - _is_function_call = True - for tool in optional_params["tools"]: - json_schemas[tool["function"]["name"]] = tool[ - "function" - ].get("parameters", None) - outputText = 
completion_response.get("content")[0].get("text", None) - if outputText is not None and contains_tag( - "invoke", outputText - ): # OUTPUT PARSE FUNCTION CALL - function_name = extract_between_tags("tool_name", outputText)[0] - function_arguments_str = extract_between_tags( - "invoke", outputText - )[0].strip() - function_arguments_str = ( - f"{function_arguments_str}" - ) - function_arguments = parse_xml_params( - function_arguments_str, - json_schema=json_schemas.get( - function_name, None - ), # check if we have a json schema for this function name) - ) - _message = litellm.Message( - tool_calls=[ - { - "id": f"call_{uuid.uuid4()}", - "type": "function", - "function": { - "name": function_name, - "arguments": json.dumps(function_arguments), - }, - } - ], - content=None, - ) - model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = ( - outputText # allow user to access raw anthropic tool calling response - ) - model_response.choices[0].finish_reason = map_finish_reason( - completion_response.get("stop_reason", "") - ) - _usage = litellm.Usage( - prompt_tokens=completion_response["usage"]["input_tokens"], - completion_tokens=completion_response["usage"]["output_tokens"], - total_tokens=completion_response["usage"]["input_tokens"] - + completion_response["usage"]["output_tokens"], - ) - setattr(model_response, "usage", _usage) - else: - outputText = completion_response["completion"] - - model_response.choices[0].finish_reason = completion_response[ - "stop_reason" - ] + return litellm.AmazonAnthropicClaude3Config().transform_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + request_data=request_data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + api_key=api_key, + json_mode=json_mode, + ) elif provider == "ai21": outputText = ( completion_response.get("completions")[0].get("data").get("text") @@ -536,6 +439,7 @@ def get_async_custom_stream_wrapper( messages=messages, logging_obj=logging_obj, fake_stream=True if "ai21" in api_base else False, + bedrock_invoke_provider=self.get_bedrock_invoke_provider(model), ), model=model, custom_llm_provider="bedrock", @@ -569,6 +473,7 @@ def get_sync_custom_stream_wrapper( messages=messages, logging_obj=logging_obj, fake_stream=True if "ai21" in api_base else False, + bedrock_invoke_provider=self.get_bedrock_invoke_provider(model), ), model=model, custom_llm_provider="bedrock", @@ -594,10 +499,14 @@ def get_bedrock_invoke_provider( """ Helper function to get the bedrock provider from the model - handles 2 scenarions: - 1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` - 2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama` + handles 3 scenarions: + 1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` + 2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` + 3. 
model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama` """ + if model.startswith("invoke/"): + model = model.replace("invoke/", "", 1) + _split_model = model.split(".")[0] if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL): return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model) @@ -640,9 +549,9 @@ def get_bedrock_model_id( else: modelId = model + modelId = modelId.replace("invoke/", "", 1) if provider == "llama" and "llama/" in modelId: modelId = self._get_model_id_for_llama_like_model(modelId) - return modelId def _get_aws_region_name(self, optional_params: dict) -> str: diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index deed2124c4ad..8a534f6eac42 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -3,11 +3,12 @@ """ import os -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union import httpx import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.secret_managers.main import get_secret @@ -310,3 +311,68 @@ def get_bedrock_tool_name(response_tool_name: str) -> str: response_tool_name ] return response_tool_name + + +class BedrockModelInfo(BaseLLMModelInfo): + + global_config = AmazonBedrockGlobalConfig() + all_global_regions = global_config.get_all_regions() + + @staticmethod + def get_base_model(model: str) -> str: + """ + Get the base model from the given model name. + + Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + """ + if model.startswith("bedrock/"): + model = model.split("/", 1)[1] + + if model.startswith("converse/"): + model = model.split("/", 1)[1] + + if model.startswith("invoke/"): + model = model.split("/", 1)[1] + + potential_region = model.split(".", 1)[0] + + alt_potential_region = model.split("/", 1)[ + 0 + ] # in model cost map we store regional information like `/us-west-2/bedrock-model` + + if ( + potential_region + in BedrockModelInfo._supported_cross_region_inference_region() + ): + return model.split(".", 1)[1] + elif ( + alt_potential_region in BedrockModelInfo.all_global_regions + and len(model.split("/", 1)) > 1 + ): + return model.split("/", 1)[1] + + return model + + @staticmethod + def _supported_cross_region_inference_region() -> List[str]: + """ + Abbreviations of regions AWS Bedrock supports for cross region inference + """ + return ["us", "eu", "apac"] + + @staticmethod + def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]: + """ + Get the bedrock route for the given model. 
+ """ + base_model = BedrockModelInfo.get_base_model(model) + if "invoke/" in model: + return "invoke" + elif "converse_like" in model: + return "converse_like" + elif "converse/" in model: + return "converse" + elif base_model in litellm.bedrock_converse_models: + return "converse" + return "invoke" diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index 98c3254da4a0..84a57bbaa65d 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -344,6 +344,10 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]: or "https://api.openai.com/v1" ) + @staticmethod + def get_base_model(model: str) -> str: + return model + def get_model_response_iterator( self, streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], diff --git a/litellm/llms/topaz/common_utils.py b/litellm/llms/topaz/common_utils.py index fc3c69a750f7..4ef2315db4ee 100644 --- a/litellm/llms/topaz/common_utils.py +++ b/litellm/llms/topaz/common_utils.py @@ -29,3 +29,7 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]: return ( api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com" ) + + @staticmethod + def get_base_model(model: str) -> str: + return model diff --git a/litellm/main.py b/litellm/main.py index 9c63696ac291..2da5795fa28f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -68,6 +68,7 @@ get_content_from_model_response, ) from litellm.llms.base_llm.chat.transformation import BaseConfig +from litellm.llms.bedrock.common_utils import BedrockModelInfo from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.realtime_api.main import _realtime_health_check from litellm.secret_managers.main import get_secret_str @@ -2628,11 +2629,8 @@ def completion( # type: ignore # noqa: PLR0915 aws_bedrock_client.meta.region_name ) - base_model = litellm.AmazonConverseConfig()._get_base_model(model) - - if base_model in litellm.bedrock_converse_models or model.startswith( - "converse/" - ): + bedrock_route = BedrockModelInfo.get_bedrock_route(model) + if bedrock_route == "converse": model = model.replace("converse/", "") response = bedrock_converse_chat_completion.completion( model=model, @@ -2651,7 +2649,7 @@ def completion( # type: ignore # noqa: PLR0915 client=client, api_base=api_base, ) - elif "converse_like" in model: + elif bedrock_route == "converse_like": model = model.replace("converse_like/", "") response = base_llm_http_handler.completion( model=model, diff --git a/litellm/utils.py b/litellm/utils.py index 7d73548f89ab..7cdfc2ebbe46 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -86,10 +86,10 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import ( LiteLLMResponseObjectHandler, _handle_invalid_parallel_tool_calls, + _parse_content_for_reasoning, convert_to_model_response_object, convert_to_streaming_response, convert_to_streaming_response_async, - _parse_content_for_reasoning, ) from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import ( @@ -111,6 +111,7 @@ calculate_img_tokens, get_modified_max_tokens, ) +from litellm.llms.bedrock.common_utils import BedrockModelInfo from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.router_utils.get_retry_from_policy import ( get_num_retries_from_retry_policy, @@ -3189,8 +3190,8 @@ 
def _check_valid_arg(supported_params: List[str]): ), ) elif custom_llm_provider == "bedrock": - base_model = litellm.AmazonConverseConfig()._get_base_model(model) - if base_model in litellm.bedrock_converse_models: + bedrock_route = BedrockModelInfo.get_bedrock_route(model) + if bedrock_route == "converse" or bedrock_route == "converse_like": optional_params = litellm.AmazonConverseConfig().map_openai_params( model=model, non_default_params=non_default_params, @@ -3203,15 +3204,20 @@ def _check_valid_arg(supported_params: List[str]): messages=messages, ) - elif "anthropic" in model: - if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route. - if model.startswith("anthropic.claude-3"): - optional_params = ( - litellm.AmazonAnthropicClaude3Config().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) + elif "anthropic" in model and bedrock_route == "invoke": + if model.startswith("anthropic.claude-3"): + optional_params = ( + litellm.AmazonAnthropicClaude3Config().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) + ) else: optional_params = litellm.AmazonAnthropicConfig().map_openai_params( non_default_params=non_default_params, @@ -3972,8 +3978,16 @@ def _strip_stable_vertex_version(model_name) -> str: return re.sub(r"-\d+$", "", model_name) -def _strip_bedrock_region(model_name) -> str: - return litellm.AmazonConverseConfig()._get_base_model(model_name) +def _get_base_bedrock_model(model_name) -> str: + """ + Get the base model from the given model name. + + Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" + """ + from litellm.llms.bedrock.common_utils import BedrockModelInfo + + return BedrockModelInfo.get_base_model(model_name) def _strip_openai_finetune_model_name(model_name: str) -> str: @@ -3994,8 +4008,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str: def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str: if custom_llm_provider and custom_llm_provider == "bedrock": - strip_bedrock_region = _strip_bedrock_region(model_name=model) - return strip_bedrock_region + stripped_bedrock_model = _get_base_bedrock_model(model_name=model) + return stripped_bedrock_model elif custom_llm_provider and ( custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini" ): @@ -6066,24 +6080,23 @@ def get_provider_chat_config( # noqa: PLR0915 elif litellm.LlmProviders.PETALS == provider: return litellm.PetalsConfig() elif litellm.LlmProviders.BEDROCK == provider: - base_model = litellm.AmazonConverseConfig()._get_base_model(model) - bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model) - if ( - base_model in litellm.bedrock_converse_models - or "converse_like" in model - ): + bedrock_route = BedrockModelInfo.get_bedrock_route(model) + bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider( + model + ) + if bedrock_route == "converse" or bedrock_route == "converse_like": return litellm.AmazonConverseConfig() - elif bedrock_provider == "amazon": # amazon titan llms + elif bedrock_invoke_provider == "amazon": # amazon titan llms return litellm.AmazonTitanConfig() elif ( - bedrock_provider == "meta" or bedrock_provider == "llama" + bedrock_invoke_provider == "meta" or 
bedrock_invoke_provider == "llama" ): # amazon / meta llms return litellm.AmazonLlamaConfig() - elif bedrock_provider == "ai21": # ai21 llms + elif bedrock_invoke_provider == "ai21": # ai21 llms return litellm.AmazonAI21Config() - elif bedrock_provider == "cohere": # cohere models on bedrock + elif bedrock_invoke_provider == "cohere": # cohere models on bedrock return litellm.AmazonCohereConfig() - elif bedrock_provider == "mistral": # mistral models on bedrock + elif bedrock_invoke_provider == "mistral": # mistral models on bedrock return litellm.AmazonMistralConfig() else: return litellm.AmazonInvokeConfig() diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index f09c4b45a562..4d10b724fdc0 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -1265,7 +1265,9 @@ def test_bedrock_cross_region_inference(model): ], ) def test_bedrock_get_base_model(model, expected_base_model): - assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model + from litellm.llms.bedrock.common_utils import BedrockModelInfo + + assert BedrockModelInfo.get_base_model(model) == expected_base_model from litellm.litellm_core_utils.prompt_templates.factory import ( @@ -1982,9 +1984,49 @@ def test_bedrock_mapped_converse_models(): def test_bedrock_base_model_helper(): + from litellm.llms.bedrock.common_utils import BedrockModelInfo + model = "us.amazon.nova-pro-v1:0" - litellm.AmazonConverseConfig()._get_base_model(model) - assert model == "us.amazon.nova-pro-v1:0" + base_model = BedrockModelInfo.get_base_model(model) + assert base_model == "amazon.nova-pro-v1:0" + + assert ( + BedrockModelInfo.get_base_model( + "invoke/anthropic.claude-3-5-sonnet-20241022-v2:0" + ) + == "anthropic.claude-3-5-sonnet-20241022-v2:0" + ) + + +@pytest.mark.parametrize( + "model,expected_route", + [ + # Test explicit route prefixes + ("invoke/anthropic.claude-3-sonnet-20240229-v1:0", "invoke"), + ("converse/anthropic.claude-3-sonnet-20240229-v1:0", "converse"), + ("converse_like/anthropic.claude-3-sonnet-20240229-v1:0", "converse_like"), + # Test models in BEDROCK_CONVERSE_MODELS list + ("anthropic.claude-3-5-haiku-20241022-v1:0", "converse"), + ("anthropic.claude-v2", "converse"), + ("meta.llama3-70b-instruct-v1:0", "converse"), + ("mistral.mistral-large-2407-v1:0", "converse"), + # Test models with region prefixes + ("us.anthropic.claude-3-sonnet-20240229-v1:0", "converse"), + ("us.meta.llama3-70b-instruct-v1:0", "converse"), + # Test default case (should return "invoke") + ("amazon.titan-text-express-v1", "invoke"), + ("cohere.command-text-v14", "invoke"), + ("cohere.command-r-v1:0", "invoke"), + ], +) +def test_bedrock_route_detection(model, expected_route): + """Test all scenarios for BedrockModelInfo.get_bedrock_route""" + from litellm.llms.bedrock.common_utils import BedrockModelInfo + + route = BedrockModelInfo.get_bedrock_route(model) + assert ( + route == expected_route + ), f"Expected route '{expected_route}' for model '{model}', but got '{route}'" @pytest.mark.parametrize( diff --git a/tests/llm_translation/test_bedrock_invoke_claude_json.py b/tests/llm_translation/test_bedrock_invoke_claude_json.py new file mode 100644 index 000000000000..2e943ed6822b --- /dev/null +++ b/tests/llm_translation/test_bedrock_invoke_claude_json.py @@ -0,0 +1,28 @@ +from base_llm_unit_tests import BaseLLMChatTest +import pytest +import sys +import os + +sys.path.insert( + 0, os.path.abspath("../..") 
+) # Adds the parent directory to the system path +import litellm + + +class TestBedrockInvokeClaudeJson(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + litellm._turn_on_debug() + return { + "model": "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0", + } + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + pass + + @pytest.fixture(autouse=True) + def skip_non_json_tests(self, request): + if not "json" in request.function.__name__.lower(): + pytest.skip( + f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'" + ) diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py index 04bce96222b9..a15934e13427 100644 --- a/tests/llm_translation/test_max_completion_tokens.py +++ b/tests/llm_translation/test_max_completion_tokens.py @@ -305,12 +305,16 @@ def test_all_model_configs(): assert ( "max_completion_tokens" - in AmazonAnthropicClaude3Config().get_supported_openai_params() + in AmazonAnthropicClaude3Config().get_supported_openai_params( + model="anthropic.claude-3-sonnet-20240229-v1:0" + ) ) assert AmazonAnthropicClaude3Config().map_openai_params( non_default_params={"max_completion_tokens": 10}, optional_params={}, + model="anthropic.claude-3-sonnet-20240229-v1:0", + drop_params=False, ) == {"max_tokens": 10} assert ( diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index a9be17807271..c3f4c15c270f 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -208,3 +208,11 @@ def test_nova_bedrock_converse(): ) assert custom_llm_provider == "bedrock" assert model == "amazon.nova-micro-v1:0" + + +def test_bedrock_invoke_anthropic(): + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0", + ) + assert custom_llm_provider == "bedrock" + assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0" diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py index 6747efce3a53..c879332c7b5f 100644 --- a/tests/local_testing/test_get_model_info.py +++ b/tests/local_testing/test_get_model_info.py @@ -321,7 +321,7 @@ def test_get_model_info_bedrock_models(): """ Check for drift in base model info for bedrock models and regional model info for bedrock models. """ - from litellm import AmazonConverseConfig + from litellm.llms.bedrock.common_utils import BedrockModelInfo os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") @@ -337,7 +337,7 @@ def test_get_model_info_bedrock_models(): if any(commitment in k for commitment in potential_commitments): for commitment in potential_commitments: k = k.replace(f"{commitment}/", "") - base_model = AmazonConverseConfig()._get_base_model(k) + base_model = BedrockModelInfo.get_base_model(k) base_model_info = litellm.model_cost[base_model] for base_model_key, base_model_value in base_model_info.items(): if base_model_key.startswith("supports_"):
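
Usage sketch, assuming AWS credentials and region are configured in the environment: a minimal example of the `/bedrock/invoke` route for Anthropic models added in this patch, based on the routing helpers in litellm/llms/bedrock/common_utils.py and the test in tests/llm_translation/test_bedrock_invoke_claude_json.py. The specific model ID, prompts, and token limits below are illustrative, not part of the patch.

# Minimal sketch of the new bedrock/invoke route for Anthropic models.
# Assumes AWS credentials (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION_NAME)
# are available in the environment; the model ID is illustrative.
import litellm
from litellm.llms.bedrock.common_utils import BedrockModelInfo

model = "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"

# The new helpers route this model through the invoke path and strip the
# route prefix to recover the base model for cost / model-info lookups.
print(BedrockModelInfo.get_bedrock_route("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"))
# -> "invoke"
print(BedrockModelInfo.get_base_model("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"))
# -> "anthropic.claude-3-5-sonnet-20240620-v1:0"

# Non-streaming call; the request/response are translated by
# AmazonAnthropicClaude3Config, which reuses litellm.AnthropicConfig internally.
response = litellm.completion(
    model=model,
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=100,
)
print(response.choices[0].message.content)

# Streaming call; chunks are decoded by AmazonAnthropicClaudeStreamDecoder.
for chunk in litellm.completion(
    model=model,
    messages=[{"role": "user", "content": "Count to three."}],
    max_tokens=50,
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")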