
Commit b242c66

(Feat) - Add /bedrock/invoke support for all Anthropic models (#8383)
* use anthropic transformation for bedrock/invoke
* use anthropic transforms for bedrock invoke claude
* TestBedrockInvokeClaudeJson
* add AmazonAnthropicClaudeStreamDecoder
* pass bedrock_invoke_provider to make_call
* fix _get_base_bedrock_model
* fix get_bedrock_route
* fix bedrock routing
* fixes for bedrock invoke
* test_all_model_configs
* fix AWSEventStreamDecoder linting
* fix code qa
* test_bedrock_get_base_model
* test_get_model_info_bedrock_models
* test_bedrock_base_model_helper
* test_bedrock_route_detection
1 parent 1dd3713 commit b242c66

15 files changed, +386 -262 lines changed
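
For context, a rough usage sketch of what this commit enables (the model id and request parameters below are illustrative assumptions, not taken from the diff): Anthropic models can now be routed through the Bedrock Invoke API via the standard completion call.

import litellm

# Illustrative only: routes through bedrock /invoke using the Anthropic
# transformations added in this commit; the model id is an assumed example.
response = litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-opus-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello from Bedrock Invoke"}],
    max_tokens=256,
)
print(response.choices[0].message.content)

# Streaming responses are decoded by the new AmazonAnthropicClaudeStreamDecoder.
for chunk in litellm.completion(
    model="bedrock/invoke/anthropic.claude-3-opus-20240229-v1:0",
    messages=[{"role": "user", "content": "Stream this"}],
    stream=True,
):
    print(chunk)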

litellm/llms/base_llm/base_utils.py

Lines changed: 11 additions & 0 deletions
@@ -34,6 +34,17 @@ def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
 
+    @staticmethod
+    @abstractmethod
+    def get_base_model(model: str) -> Optional[str]:
+        """
+        Returns the base model name from the given model name.
+
+        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
+        This function will return `anthropic.claude-3-opus-20240229-v1:0`
+        """
+        pass
+
 
 def _dict_to_response_format_helper(
     response_format: dict, ref_template: Optional[str] = None

litellm/llms/bedrock/chat/converse_transformation.py

Lines changed: 21 additions & 53 deletions
@@ -33,14 +33,7 @@
 from litellm.types.utils import ModelResponse, Usage
 from litellm.utils import add_dummy_tool, has_tool_call_blocks
 
-from ..common_utils import (
-    AmazonBedrockGlobalConfig,
-    BedrockError,
-    get_bedrock_tool_name,
-)
-
-global_config = AmazonBedrockGlobalConfig()
-all_global_regions = global_config.get_all_regions()
+from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
 
 
 class AmazonConverseConfig(BaseConfig):
@@ -104,7 +97,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
         ]
 
         ## Filter out 'cross-region' from model name
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         if (
             base_model.startswith("anthropic")
@@ -341,9 +334,9 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
-
+
     def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         val_top_k = None
         if "topK" in inference_params:
@@ -352,11 +345,11 @@ def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
             val_top_k = inference_params.pop("top_k")
 
         if val_top_k:
-            if (base_model.startswith("anthropic")):
+            if base_model.startswith("anthropic"):
                 return {"top_k": val_top_k}
             if base_model.startswith("amazon.nova"):
-                return {'inferenceConfig': {"topK": val_top_k}}
-
+                return {"inferenceConfig": {"topK": val_top_k}}
+
         return {}
 
     def _transform_request_helper(
@@ -393,15 +386,25 @@ def _transform_request_helper(
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
+        total_supported_params = (
+            supported_converse_params
+            + supported_tool_call_params
+            + supported_guardrail_params
+        )
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
         # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
-        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
-        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
+        additional_request_params = {
+            k: v for k, v in inference_params.items() if k not in total_supported_params
+        }
+        inference_params = {
+            k: v for k, v in inference_params.items() if k in total_supported_params
+        }
 
         # Only set the topK value in for models that support it
-        additional_request_params.update(self._handle_top_k_value(model, inference_params))
+        additional_request_params.update(
+            self._handle_top_k_value(model, inference_params)
+        )
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
@@ -679,41 +682,6 @@ def _transform_response(
 
         return model_response
 
-    def _supported_cross_region_inference_region(self) -> List[str]:
-        """
-        Abbreviations of regions AWS Bedrock supports for cross region inference
-        """
-        return ["us", "eu", "apac"]
-
-    def _get_base_model(self, model: str) -> str:
-        """
-        Get the base model from the given model name.
-
-        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        """
-
-        if model.startswith("bedrock/"):
-            model = model.split("/", 1)[1]
-
-        if model.startswith("converse/"):
-            model = model.split("/", 1)[1]
-
-        potential_region = model.split(".", 1)[0]
-
-        alt_potential_region = model.split("/", 1)[
-            0
-        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
-
-        if potential_region in self._supported_cross_region_inference_region():
-            return model.split(".", 1)[1]
-        elif (
-            alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
-        ):
-            return model.split("/", 1)[1]
-
-        return model
-
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
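
The deleted _get_base_model helper is replaced by BedrockModelInfo.get_base_model, imported from ..common_utils; that implementation is not part of this diff, so the following is only a minimal sketch of the behavior the callers above rely on, assuming it mirrors the removed method plus the new `invoke/` route prefix:

# Minimal sketch (assumed): strip route prefixes and cross-region abbreviations.
def get_base_model_sketch(model: str) -> str:
    for prefix in ("bedrock/", "converse/", "invoke/"):
        if model.startswith(prefix):
            model = model.split("/", 1)[1]
    potential_region = model.split(".", 1)[0]
    if potential_region in ("us", "eu", "apac"):  # cross-region inference prefixes
        return model.split(".", 1)[1]
    return model

assert get_base_model_sketch("invoke/anthropic.claude-3-opus-20240229-v1:0") == "anthropic.claude-3-opus-20240229-v1:0"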

litellm/llms/bedrock/chat/invoke_handler.py

Lines changed: 41 additions & 1 deletion
@@ -40,6 +40,9 @@
     parse_xml_params,
     prompt_factory,
 )
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -177,6 +180,7 @@ async def make_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -214,6 +218,14 @@ async def make_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=False,
+            )
+            completion_stream = decoder.aiter_bytes(
+                response.aiter_bytes(chunk_size=1024)
+            )
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.aiter_bytes(
@@ -248,6 +260,7 @@ def make_sync_call(
     logging_obj: Logging,
     fake_stream: bool = False,
     json_mode: Optional[bool] = False,
+    bedrock_invoke_provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL] = None,
 ):
     try:
         if client is None:
@@ -283,6 +296,12 @@ def make_sync_call(
             completion_stream: Any = MockResponseIterator(
                 model_response=model_response, json_mode=json_mode
             )
+        elif bedrock_invoke_provider == "anthropic":
+            decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
+                model=model,
+                sync_stream=True,
+            )
+            completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
         else:
             decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -1323,7 +1342,7 @@ def _chunk_parser(self, chunk_data: dict) -> GChunk:
             text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
             is_finished = True
             finish_reason = "stop"
-        ######## bedrock.anthropic mappings ###############
+        ######## converse bedrock.anthropic mappings ###############
         elif (
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data
@@ -1429,6 +1448,27 @@ def _parse_message_from_event(self, event) -> Optional[str]:
         return chunk.decode()  # type: ignore[no-any-return]
 
 
+class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
+    def __init__(
+        self,
+        model: str,
+        sync_stream: bool,
+    ) -> None:
+        """
+        Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
+
+        The only difference between AWSEventStreamDecoder and AmazonAnthropicClaudeStreamDecoder is the `chunk_parser` method
+        """
+        super().__init__(model=model)
+        self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=sync_stream,
+        )
+
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
+        return self.anthropic_model_response_iterator.chunk_parser(chunk=chunk_data)
+
+
 class MockResponseIterator:  # for returning ai21 streaming responses
     def __init__(self, model_response, json_mode: Optional[bool] = False):
         self.model_response = model_response
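
Taken together with the bedrock_invoke_provider parameter added to make_call / make_sync_call, the streaming path now branches roughly as sketched below (condensed from the hunks above; surrounding variables such as response, fake_stream, and json_mode are assumed from the enclosing function):

# Condensed sketch of the decoder selection (sync variant shown).
if fake_stream:
    completion_stream = MockResponseIterator(
        model_response=model_response, json_mode=json_mode
    )
elif bedrock_invoke_provider == "anthropic":
    # bedrock /invoke returns Anthropic-native event payloads, so the Anthropic
    # chunk parser is reused via the new AmazonAnthropicClaudeStreamDecoder.
    decoder = AmazonAnthropicClaudeStreamDecoder(model=model, sync_stream=True)
    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
else:
    # converse and other invoke providers keep the generic AWS event stream decoder.
    decoder = AWSEventStreamDecoder(model=model)
    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
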
Lines changed: 77 additions & 48 deletions
@@ -1,61 +1,34 @@
-import types
-from typing import List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
+import httpx
 
-class AmazonAnthropicClaude3Config:
+import litellm
+from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
+    AmazonInvokeConfig,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
     """
     Reference:
         https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
         https://docs.anthropic.com/claude/docs/models-overview#model-comparison
 
     Supported Params for the Amazon / Anthropic Claude 3 models:
-
-    - `max_tokens` Required (integer) max tokens. Default is 4096
-    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
-    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
-    - `temperature` Optional (float) The amount of randomness injected into the response
-    - `top_p` Optional (float) Use nucleus sampling.
-    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
-    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
     """
 
-    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
-    anthropic_version: Optional[str] = "bedrock-2023-05-31"
-    system: Optional[str] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[List[str]] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        anthropic_version: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
+    anthropic_version: str = "bedrock-2023-05-31"
 
-    def get_supported_openai_params(self):
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",
@@ -68,7 +41,13 @@ def get_supported_openai_params(self):
             "extra_headers",
         ]
 
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
@@ -83,3 +62,53 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):
             if param == "top_p":
                 optional_params["top_p"] = value
         return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        _anthropic_request = litellm.AnthropicConfig().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        _anthropic_request.pop("model", None)
+        if "anthropic_version" not in _anthropic_request:
+            _anthropic_request["anthropic_version"] = self.anthropic_version
+
+        return _anthropic_request
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return litellm.AnthropicConfig().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
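
Since transform_request delegates to litellm.AnthropicConfig, drops the "model" key, and injects anthropic_version, the /invoke request body ends up looking roughly like the sketch below (field values are illustrative assumptions; the exact shape depends on AnthropicConfig's output, and Bedrock reads the model id from the request URL rather than the body):

# Illustrative request body produced by transform_request (values assumed).
request_body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 256,
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "Hello from Bedrock Invoke"}]}
    ],
    # no "model" key: it is popped above because Bedrock takes the model id from the URL
}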
