
Commit c4d04e7

fix _get_base_bedrock_model
1 parent 28e5812 commit c4d04e7

6 files changed: +112 −54


litellm/llms/base_llm/base_utils.py (+11)

@@ -34,6 +34,17 @@ def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
 
+    @staticmethod
+    @abstractmethod
+    def get_base_model(model: str) -> Optional[str]:
+        """
+        Returns the base model name from the given model name.
+
+        Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
+        This function will return `anthropic.claude-3-opus-20240229-v1:0`
+        """
+        pass
+
 
 def _dict_to_response_format_helper(
     response_format: dict, ref_template: Optional[str] = None
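
As a reading aid, here is a minimal sketch of how a concrete provider class might satisfy this new abstract method; `ExampleProviderModelInfo` and its `_ROUTE_PREFIXES` constant are hypothetical, not part of this commit:

    from typing import Optional

    class ExampleProviderModelInfo:  # hypothetical BaseLLMModelInfo subclass
        # Illustrative routing prefixes; real providers define their own.
        _ROUTE_PREFIXES = ("invoke/", "converse/")

        @staticmethod
        def get_base_model(model: str) -> Optional[str]:
            # Strip one leading routing prefix, if present, and return the bare id.
            for prefix in ExampleProviderModelInfo._ROUTE_PREFIXES:
                if model.startswith(prefix):
                    return model[len(prefix):]
            return model

    ExampleProviderModelInfo.get_base_model("invoke/anthropic.claude-3-opus-20240229-v1:0")
    # -> "anthropic.claude-3-opus-20240229-v1:0"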

litellm/llms/bedrock/chat/converse_transformation.py (+21 −48)

@@ -36,12 +36,10 @@
 from ..common_utils import (
     AmazonBedrockGlobalConfig,
     BedrockError,
+    BedrockModelInfo,
     get_bedrock_tool_name,
 )
 
-global_config = AmazonBedrockGlobalConfig()
-all_global_regions = global_config.get_all_regions()
-
 
 class AmazonConverseConfig(BaseConfig):
     """
@@ -104,7 +102,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
         ]
 
         ## Filter out 'cross-region' from model name
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         if (
             base_model.startswith("anthropic")
@@ -341,9 +339,9 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
         if "top_k" in inference_params:
             inference_params["topK"] = inference_params.pop("top_k")
         return InferenceConfig(**inference_params)
-
+
     def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
-        base_model = self._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
 
         val_top_k = None
         if "topK" in inference_params:
@@ -352,11 +350,11 @@ def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
             val_top_k = inference_params.pop("top_k")
 
         if val_top_k:
-            if (base_model.startswith("anthropic")):
+            if base_model.startswith("anthropic"):
                 return {"top_k": val_top_k}
             if base_model.startswith("amazon.nova"):
-                return {'inferenceConfig': {"topK": val_top_k}}
-
+                return {"inferenceConfig": {"topK": val_top_k}}
+
         return {}
 
     def _transform_request_helper(
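
To make the branching above concrete, a hand-traced sketch of where the top-k value lands per model family (the model ids are illustrative; the logic is copied from the diff):

    # Anthropic models take the value at the top level as `top_k`;
    # Amazon Nova models nest it as `inferenceConfig.topK`; others drop it.
    def place_top_k(base_model: str, val_top_k: int) -> dict:
        if base_model.startswith("anthropic"):
            return {"top_k": val_top_k}
        if base_model.startswith("amazon.nova"):
            return {"inferenceConfig": {"topK": val_top_k}}
        return {}

    place_top_k("anthropic.claude-3-opus-20240229-v1:0", 40)  # {"top_k": 40}
    place_top_k("amazon.nova-pro-v1:0", 40)  # {"inferenceConfig": {"topK": 40}}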
@@ -393,15 +391,25 @@ def _transform_request_helper(
         ) + ["top_k"]
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
+        total_supported_params = (
+            supported_converse_params
+            + supported_tool_call_params
+            + supported_guardrail_params
+        )
         inference_params.pop("json_mode", None)  # used for handling json_schema
 
         # keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
-        additional_request_params = {k: v for k, v in inference_params.items() if k not in total_supported_params}
-        inference_params = {k: v for k, v in inference_params.items() if k in total_supported_params}
+        additional_request_params = {
+            k: v for k, v in inference_params.items() if k not in total_supported_params
+        }
+        inference_params = {
+            k: v for k, v in inference_params.items() if k in total_supported_params
+        }
 
         # Only set the topK value in for models that support it
-        additional_request_params.update(self._handle_top_k_value(model, inference_params))
+        additional_request_params.update(
+            self._handle_top_k_value(model, inference_params)
+        )
 
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
             inference_params.pop("tools", [])
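
The reflowed dict comprehensions above split the request parameters into two buckets; a self-contained illustration with made-up values (param names outside the supported list are invented for the example):

    # Supported params stay in inference_params; everything else is passed
    # through to the provider via additional_request_params.
    params = {"temperature": 0.2, "top_k": 40, "vendor_flag": True}
    total_supported_params = ["temperature", "top_k", "tools", "tool_choice", "guardrailConfig"]

    additional_request_params = {k: v for k, v in params.items() if k not in total_supported_params}
    inference_params = {k: v for k, v in params.items() if k in total_supported_params}

    assert additional_request_params == {"vendor_flag": True}
    assert inference_params == {"temperature": 0.2, "top_k": 40}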
@@ -679,41 +687,6 @@ def _transform_response(
 
         return model_response
 
-    def _supported_cross_region_inference_region(self) -> List[str]:
-        """
-        Abbreviations of regions AWS Bedrock supports for cross region inference
-        """
-        return ["us", "eu", "apac"]
-
-    def _get_base_model(self, model: str) -> str:
-        """
-        Get the base model from the given model name.
-
-        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
-        """
-
-        if model.startswith("bedrock/"):
-            model = model.split("/", 1)[1]
-
-        if model.startswith("converse/"):
-            model = model.split("/", 1)[1]
-
-        potential_region = model.split(".", 1)[0]
-
-        alt_potential_region = model.split("/", 1)[
-            0
-        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
-
-        if potential_region in self._supported_cross_region_inference_region():
-            return model.split(".", 1)[1]
-        elif (
-            alt_potential_region in all_global_regions and len(model.split("/", 1)) > 1
-        ):
-            return model.split("/", 1)[1]
-
-        return model
-
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:

litellm/llms/bedrock/common_utils.py (+50)

@@ -8,6 +8,7 @@
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret

@@ -310,3 +311,52 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
             response_tool_name
         ]
     return response_tool_name
+
+
+class BedrockModelInfo(BaseLLMModelInfo):
+
+    global_config = AmazonBedrockGlobalConfig()
+    all_global_regions = global_config.get_all_regions()
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        """
+        Get the base model from the given model name.
+
+        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+        """
+        if model.startswith("bedrock/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("converse/"):
+            model = model.split("/", 1)[1]
+
+        if model.startswith("invoke/"):
+            model = model.split("/", 1)[1]
+
+        potential_region = model.split(".", 1)[0]
+
+        alt_potential_region = model.split("/", 1)[
+            0
+        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`
+
+        if (
+            potential_region
+            in BedrockModelInfo._supported_cross_region_inference_region()
+        ):
+            return model.split(".", 1)[1]
+        elif (
+            alt_potential_region in BedrockModelInfo.all_global_regions
+            and len(model.split("/", 1)) > 1
+        ):
+            return model.split("/", 1)[1]
+
+        return model
+
+    @staticmethod
+    def _supported_cross_region_inference_region() -> List[str]:
+        """
+        Abbreviations of regions AWS Bedrock supports for cross region inference
+        """
+        return ["us", "eu", "apac"]

litellm/llms/openai/chat/gpt_transformation.py (+4)

@@ -344,6 +344,10 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
             or "https://api.openai.com/v1"
         )
 
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model
+
     def get_model_response_iterator(
         self,
         streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],

litellm/llms/topaz/common_utils.py (+4)

@@ -29,3 +29,7 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return (
             api_base or get_secret_str("TOPAZ_API_BASE") or "https://api.topazlabs.com"
         )
+
+    @staticmethod
+    def get_base_model(model: str) -> str:
+        return model

litellm/utils.py (+22 −6)

@@ -110,6 +110,7 @@
     calculate_img_tokens,
     get_modified_max_tokens,
 )
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.router_utils.get_retry_from_policy import (
     get_num_retries_from_retry_policy,
@@ -3188,7 +3189,7 @@ def _check_valid_arg(supported_params: List[str]):
             ),
         )
     elif custom_llm_provider == "bedrock":
-        base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
         if base_model in litellm.bedrock_converse_models:
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
@@ -3209,6 +3210,13 @@ def _check_valid_arg(supported_params: List[str]):
                 litellm.AmazonAnthropicClaude3Config().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
+                    model=model,
+                    drop_params=(
+                        drop_params
+                        if drop_params is not None
+                        and isinstance(drop_params, bool)
+                        else False
+                    ),
                 )
             )
         else:
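
The `drop_params` guard added here is equivalent to a small coercion helper; `_coerce_drop_params` is hypothetical, shown only to name the pattern:

    from typing import Optional

    def _coerce_drop_params(drop_params: Optional[bool]) -> bool:
        # Mirrors the inline expression: pass the value through only when it
        # is an explicit bool; treat None (or any other type) as False.
        return drop_params if drop_params is not None and isinstance(drop_params, bool) else False

    assert _coerce_drop_params(True) is True
    assert _coerce_drop_params(None) is False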
@@ -3971,8 +3979,16 @@ def _strip_stable_vertex_version(model_name) -> str:
     return re.sub(r"-\d+$", "", model_name)
 
 
-def _strip_bedrock_region(model_name) -> str:
-    return litellm.AmazonConverseConfig()._get_base_model(model_name)
+def _get_base_bedrock_model(model_name) -> str:
+    """
+    Get the base model from the given model name.
+
+    Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    """
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    return BedrockModelInfo.get_base_model(model_name)
 
 
 def _strip_openai_finetune_model_name(model_name: str) -> str:
@@ -3993,8 +4009,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
 
 def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
     if custom_llm_provider and custom_llm_provider == "bedrock":
-        strip_bedrock_region = _strip_bedrock_region(model_name=model)
-        return strip_bedrock_region
+        stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
+        return stripped_bedrock_model
     elif custom_llm_provider and (
         custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
     ):
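
A quick hand-traced check of the renamed helper (illustrative, not captured test output; assumes `_strip_model_name` stays importable from `litellm.utils`):

    from litellm.utils import _strip_model_name

    # The bedrock branch now delegates to BedrockModelInfo.get_base_model,
    # so cross-region prefixes are stripped before any model-map lookup.
    _strip_model_name(model="us.meta.llama3-2-11b-instruct-v1:0", custom_llm_provider="bedrock")
    # -> "meta.llama3-2-11b-instruct-v1:0"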
@@ -6065,7 +6081,7 @@ def get_provider_chat_config(  # noqa: PLR0915
     elif litellm.LlmProviders.PETALS == provider:
         return litellm.PetalsConfig()
     elif litellm.LlmProviders.BEDROCK == provider:
-        base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+        base_model = BedrockModelInfo.get_base_model(model)
         bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
         if (
             base_model in litellm.bedrock_converse_models
