Skip to content

Commit 6678dff

Browse files
committed
fix get_bedrock_route
1 parent c4d04e7 commit 6678dff

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

litellm/llms/bedrock/common_utils.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import os
6-
from typing import List, Optional, Union
6+
from typing import List, Literal, Optional, Union
77

88
import httpx
99

@@ -360,3 +360,19 @@ def _supported_cross_region_inference_region() -> List[str]:
360360
Abbreviations of regions AWS Bedrock supports for cross region inference
361361
"""
362362
return ["us", "eu", "apac"]
363+
364+
@staticmethod
365+
def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
366+
"""
367+
Get the bedrock route for the given model.
368+
"""
369+
base_model = BedrockModelInfo.get_base_model(model)
370+
if "invoke/" in model:
371+
return "invoke"
372+
elif "converse_like" in model:
373+
return "converse_like"
374+
elif "converse/" in model:
375+
return "converse"
376+
elif base_model in litellm.bedrock_converse_models:
377+
return "converse"
378+
return "invoke"

litellm/main.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
get_content_from_model_response,
6969
)
7070
from litellm.llms.base_llm.chat.transformation import BaseConfig
71+
from litellm.llms.bedrock.common_utils import BedrockModelInfo
7172
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
7273
from litellm.realtime_api.main import _realtime_health_check
7374
from litellm.secret_managers.main import get_secret_str
@@ -2628,11 +2629,8 @@ def completion( # type: ignore # noqa: PLR0915
26282629
aws_bedrock_client.meta.region_name
26292630
)
26302631

2631-
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
2632-
2633-
if base_model in litellm.bedrock_converse_models or model.startswith(
2634-
"converse/"
2635-
):
2632+
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
2633+
if bedrock_route == "converse":
26362634
model = model.replace("converse/", "")
26372635
response = bedrock_converse_chat_completion.completion(
26382636
model=model,
@@ -2651,7 +2649,7 @@ def completion( # type: ignore # noqa: PLR0915
26512649
client=client,
26522650
api_base=api_base,
26532651
)
2654-
elif "converse_like" in model:
2652+
elif bedrock_route == "converse_like":
26552653
model = model.replace("converse_like/", "")
26562654
response = base_llm_http_handler.completion(
26572655
model=model,

0 commit comments

Comments
 (0)