36
36
from ..common_utils import (
37
37
AmazonBedrockGlobalConfig ,
38
38
BedrockError ,
39
+ BedrockModelInfo ,
39
40
get_bedrock_tool_name ,
40
41
)
41
42
42
- global_config = AmazonBedrockGlobalConfig ()
43
- all_global_regions = global_config .get_all_regions ()
44
-
45
43
46
44
class AmazonConverseConfig (BaseConfig ):
47
45
"""
@@ -104,7 +102,7 @@ def get_supported_openai_params(self, model: str) -> List[str]:
104
102
]
105
103
106
104
## Filter out 'cross-region' from model name
107
- base_model = self . _get_base_model (model )
105
+ base_model = BedrockModelInfo . get_base_model (model )
108
106
109
107
if (
110
108
base_model .startswith ("anthropic" )
@@ -341,9 +339,9 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
341
339
if "top_k" in inference_params :
342
340
inference_params ["topK" ] = inference_params .pop ("top_k" )
343
341
return InferenceConfig (** inference_params )
344
-
342
+
345
343
def _handle_top_k_value (self , model : str , inference_params : dict ) -> dict :
346
- base_model = self . _get_base_model (model )
344
+ base_model = BedrockModelInfo . get_base_model (model )
347
345
348
346
val_top_k = None
349
347
if "topK" in inference_params :
@@ -352,11 +350,11 @@ def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
352
350
val_top_k = inference_params .pop ("top_k" )
353
351
354
352
if val_top_k :
355
- if ( base_model .startswith ("anthropic" ) ):
353
+ if base_model .startswith ("anthropic" ):
356
354
return {"top_k" : val_top_k }
357
355
if base_model .startswith ("amazon.nova" ):
358
- return {' inferenceConfig' : {"topK" : val_top_k }}
359
-
356
+ return {" inferenceConfig" : {"topK" : val_top_k }}
357
+
360
358
return {}
361
359
362
360
def _transform_request_helper (
@@ -393,15 +391,25 @@ def _transform_request_helper(
393
391
) + ["top_k" ]
394
392
supported_tool_call_params = ["tools" , "tool_choice" ]
395
393
supported_guardrail_params = ["guardrailConfig" ]
396
- total_supported_params = supported_converse_params + supported_tool_call_params + supported_guardrail_params
394
+ total_supported_params = (
395
+ supported_converse_params
396
+ + supported_tool_call_params
397
+ + supported_guardrail_params
398
+ )
397
399
inference_params .pop ("json_mode" , None ) # used for handling json_schema
398
400
399
401
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
400
- additional_request_params = {k : v for k , v in inference_params .items () if k not in total_supported_params }
401
- inference_params = {k : v for k , v in inference_params .items () if k in total_supported_params }
402
+ additional_request_params = {
403
+ k : v for k , v in inference_params .items () if k not in total_supported_params
404
+ }
405
+ inference_params = {
406
+ k : v for k , v in inference_params .items () if k in total_supported_params
407
+ }
402
408
403
409
# Only set the topK value in for models that support it
404
- additional_request_params .update (self ._handle_top_k_value (model , inference_params ))
410
+ additional_request_params .update (
411
+ self ._handle_top_k_value (model , inference_params )
412
+ )
405
413
406
414
bedrock_tools : List [ToolBlock ] = _bedrock_tools_pt (
407
415
inference_params .pop ("tools" , [])
@@ -679,41 +687,6 @@ def _transform_response(
679
687
680
688
return model_response
681
689
682
- def _supported_cross_region_inference_region (self ) -> List [str ]:
683
- """
684
- Abbreviations of regions AWS Bedrock supports for cross region inference
685
- """
686
- return ["us" , "eu" , "apac" ]
687
-
688
- def _get_base_model (self , model : str ) -> str :
689
- """
690
- Get the base model from the given model name.
691
-
692
- Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
693
- AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
694
- """
695
-
696
- if model .startswith ("bedrock/" ):
697
- model = model .split ("/" , 1 )[1 ]
698
-
699
- if model .startswith ("converse/" ):
700
- model = model .split ("/" , 1 )[1 ]
701
-
702
- potential_region = model .split ("." , 1 )[0 ]
703
-
704
- alt_potential_region = model .split ("/" , 1 )[
705
- 0
706
- ] # in model cost map we store regional information like `/us-west-2/bedrock-model`
707
-
708
- if potential_region in self ._supported_cross_region_inference_region ():
709
- return model .split ("." , 1 )[1 ]
710
- elif (
711
- alt_potential_region in all_global_regions and len (model .split ("/" , 1 )) > 1
712
- ):
713
- return model .split ("/" , 1 )[1 ]
714
-
715
- return model
716
-
717
690
def get_error_class (
718
691
self , error_message : str , status_code : int , headers : Union [dict , httpx .Headers ]
719
692
) -> BaseLLMException :
0 commit comments