(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)
* use anthropic transformation for bedrock/invoke

* use anthropic transforms for bedrock invoke claude

* TestBedrockInvokeClaudeJson

* add AmazonAnthropicClaudeStreamDecoder

* pass bedrock_invoke_provider to make_call

* fix _get_base_bedrock_model

* fix get_bedrock_route

* fix bedrock routing

* fixes for bedrock invoke

* test_all_model_configs

* fix AWSEventStreamDecoder linting

* fix code qa

* test_bedrock_get_base_model

* test_get_model_info_bedrock_models

* test_bedrock_base_model_helper

* test_bedrock_route_detection
ishaan-jaff authored Feb 8, 2025
1 parent 1dd3713 commit b242c66
Showing 15 changed files with 386 additions and 262 deletions.
11 changes: 11 additions & 0 deletions litellm/llms/base_llm/base_utils.py
@@ -34,6 +34,17 @@ def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
pass

@staticmethod
@abstractmethod
def get_base_model(model: str) -> Optional[str]:
"""
Returns the base model name from the given model name.
Some providers, like Bedrock, can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`.
This function will return `anthropic.claude-3-opus-20240229-v1:0`.
"""
pass


def _dict_to_response_format_helper(
response_format: dict, ref_template: Optional[str] = None
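
For illustration, a minimal sketch of what a provider's implementation of the new `get_base_model` hook might look like, assuming only the `bedrock/`, `invoke/`, and `converse/` route prefixes mentioned in the docstring; the real helper used in this PR (`BedrockModelInfo.get_base_model`, imported in converse_transformation.py below) presumably also handles cross-region and region-prefixed model names.

```python
from typing import Optional


def get_base_model(model: str) -> Optional[str]:
    # Illustrative sketch only: strip known route prefixes to recover the
    # underlying Bedrock model id.
    for prefix in ("bedrock/", "invoke/", "converse/"):
        if model.startswith(prefix):
            model = model[len(prefix):]
    return model


assert get_base_model("invoke/anthropic.claude-3-opus-20240229-v1:0") == (
    "anthropic.claude-3-opus-20240229-v1:0"
)
```
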
74 changes: 21 additions & 53 deletions litellm/llms/bedrock/chat/converse_transformation.py
@@ -33,14 +33,7 @@
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import add_dummy_tool, has_tool_call_blocks

from ..common_utils import (
AmazonBedrockGlobalConfig,
BedrockError,
get_bedrock_tool_name,
)

global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name


class AmazonConverseConfig(BaseConfig):
Expand Down Expand Up @@ -104,7 +97,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
]

## Filter out 'cross-region' from model name
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)

if (
base_model.startswith("anthropic")
Expand Down Expand Up @@ -341,9 +334,9 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
if "top_k" in inference_params:
inference_params["topK"] = inference_params.pop("top_k")
return InferenceConfig(**inference_params)

def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
base_model = self._get_base_model(model)
base_model = BedrockModelInfo.get_base_model(model)

val_top_k = None
if "topK" in inference_params:
@@ -352,11 +345,11 @@ def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
val_top_k = inference_params.pop("top_k")

if val_top_k:
if (base_model.startswith("anthropic")):
if base_model.startswith("anthropic"):
return {"top_k": val_top_k}
if base_model.startswith("amazon.nova"):
return {'inferenceConfig': {"topK": val_top_k}}
return {"inferenceConfig": {"topK": val_top_k}}

return {}

def _transform_request_helper(
Expand Down Expand Up @@ -393,15 +386,25 @@ def _transform_request_helper(
) + ["top_k"]
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
total_supported_params = (
supported_converse_params
+ supported_tool_call_params
+ supported_guardrail_params
)
inference_params.pop("json_mode", None) # used for handling json_schema

# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
additional_request_params = {
k: v for k, v in inference_params.items() if k not in total_supported_params
}
inference_params = {
k: v for k, v in inference_params.items() if k in total_supported_params
}

# Only set the topK value for models that support it
additional_request_params.update(self._handle_top_k_value(model, inference_params))
additional_request_params.update(
self._handle_top_k_value(model, inference_params)
)

bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
@@ -679,41 +682,6 @@ def _transform_response(

return model_response

def _supported_cross_region_inference_region(self) -> List[str]:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu", "apac"]

def _get_base_model(self, model: str) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""

if model.startswith("bedrock/"):
model = model.split("/", 1)[1]

if model.startswith("converse/"):
model = model.split("/", 1)[1]

potential_region = model.split(".", 1)[0]

alt_potential_region = model.split("/", 1)[
0
] # in model cost map we store regional information like `/us-west-2/bedrock-model`

if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]
elif (
alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
):
return model.split("/", 1)[1]

return model

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
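
The net effect of the converse_transformation.py changes above is that base-model detection now goes through the shared `BedrockModelInfo.get_base_model`, and `_handle_top_k_value` places the sampling parameter differently per model family before it is merged into `additional_request_params`. A condensed, standalone sketch of that branching (function name and form are illustrative):

```python
def place_top_k(base_model: str, inference_params: dict) -> dict:
    # Pop either spelling of the parameter, mirroring _handle_top_k_value above.
    val_top_k = inference_params.pop("topK", None) or inference_params.pop("top_k", None)
    if not val_top_k:
        return {}
    if base_model.startswith("anthropic"):
        # Anthropic models take a top-level "top_k" field.
        return {"top_k": val_top_k}
    if base_model.startswith("amazon.nova"):
        # Nova models expect topK nested under inferenceConfig.
        return {"inferenceConfig": {"topK": val_top_k}}
    return {}
```
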
42 changes: 41 additions & 1 deletion litellm/llms/bedrock/chat/invoke_handler.py
@@ -40,6 +40,9 @@
parse_xml_params,
prompt_factory,
)
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -177,6 +180,7 @@ async def make_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
@@ -214,6 +218,14 @@
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=False,
)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(
@@ -248,6 +260,7 @@ def make_sync_call(
logging_obj: Logging,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
):
try:
if client is None:
Expand Down Expand Up @@ -283,6 +296,12 @@ def make_sync_call(
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
elif bedrock_invoke_provider == "anthropic":
decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
model=model,
sync_stream=True,
)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -1323,7 +1342,7 @@ def _chunk_parser(self, chunk_data: dict) -> GChunk:
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
######## converse bedrock.anthropic mappings ###############
elif (
"contentBlockIndex" in chunk_data
or "stopReason" in chunk_data
@@ -1429,6 +1448,27 @@ def _parse_message_from_event(self, event) -> Optional[str]:
return chunk.decode() # type: ignore[no-any-return]


class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
def __init__(
self,
model: str,
sync_stream: bool,
) -> None:
"""
Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
"""
super().__init__(model=model)
self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=sync_stream,
)

def _chunk_parser(self, chunk_data: dict) -> GChunk:
return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)


class MockResponseIterator: # for returning ai21 streaming responses
def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response
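
With the new `bedrock_invoke_provider` argument, `make_call` and `make_sync_call` select a decoder per provider: Claude invoke streams are parsed by `AmazonAnthropicClaudeStreamDecoder` (which reuses Anthropic's chunk parser), while everything else falls back to the generic `AWSEventStreamDecoder`. A condensed sketch of that selection, assuming the module path shown in this diff:

```python
from typing import Optional

from litellm.llms.bedrock.chat.invoke_handler import (
    AmazonAnthropicClaudeStreamDecoder,
    AWSEventStreamDecoder,
)


def pick_decoder(
    model: str,
    bedrock_invoke_provider: Optional[str],
    sync_stream: bool,
) -> AWSEventStreamDecoder:
    if bedrock_invoke_provider == "anthropic":
        # Claude invoke responses are parsed with Anthropic's own chunk parser.
        return AmazonAnthropicClaudeStreamDecoder(model=model, sync_stream=sync_stream)
    # Every other invoke provider keeps the generic AWS event-stream decoder.
    return AWSEventStreamDecoder(model=model)
```
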
@@ -1,61 +1,34 @@
import types
from typing import List, Optional
from typing import TYPE_CHECKING, Any, List, Optional

import httpx

class AmazonAnthropicClaude3Config:
import litellm
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any


class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
"""
Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` Required (integer) max tokens. Default is 4096
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
"""

max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
anthropic_version: Optional[str] = "bedrock-2023-05-31"
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None

def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
anthropic_version: str = "bedrock-2023-05-31"

def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str):
return [
"max_tokens",
"max_completion_tokens",
@@ -68,7 +41,13 @@ def get_supported_openai_params(self):
"extra_headers",
]

def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
@@ -83,3 +62,53 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):
if param == "top_p":
optional_params["top_p"] = value
return optional_params

def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
_anthropic_request = litellm.AnthropicConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)

_anthropic_request.pop("model", None)
if "anthropic_version" not in _anthropic_request:
_anthropic_request["anthropic_version"] = self.anthropic_version

return _anthropic_request

def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return litellm.AnthropicConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
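
Taken together, these changes let invoke-route Claude models flow through the standard completion interface, with request/response shaping delegated to `litellm.AnthropicConfig`. A usage sketch, assuming AWS credentials and region are already configured in the environment and using the model id from the docstring above:

```python
import litellm

# Hedged usage sketch: non-streaming call over the Bedrock invoke route.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-opus-20240229-v1:0",
    messages=[{"role": "user", "content": "Say hello via the invoke route."}],
    max_tokens=128,
)
print(response.choices[0].message.content)

# Streaming exercises the AmazonAnthropicClaudeStreamDecoder added above.
stream = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-opus-20240229-v1:0",
    messages=[{"role": "user", "content": "Stream a short greeting."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```
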
