1
- from anthropic import AnthropicBedrock
1
+ import boto3
2
2
from provider import Provider , LLMError
3
3
4
4
class AWSBedrockProvider(Provider):
    """AWS Bedrock provider built on the Converse API.

    Responses are normalized to OpenAI's chat-completion response shape so
    callers can treat all providers uniformly.
    """

    def __init__(self, **config):
        """Create a Bedrock Runtime client.

        The client uses the default AWS credential providers (e.g.
        ~/.aws/credentials or the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY
        environment variables). Any overrides from the user are passed
        straight through to boto3.client().
        """
        self.client = boto3.client("bedrock-runtime", **config)
        # Inference parameters the Converse API accepts via inferenceConfig.
        # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
        self.inference_parameters = ['maxTokens', 'temperature', 'topP', 'stopSequences']

    def normalize_response(self, response):
        """Normalize a Bedrock Converse response to OpenAI's response format.

        Converse returns ``output.message.content`` as a list of content
        blocks; OpenAI's format expects a single string, so the text blocks
        are concatenated. Missing message/content yields an empty string.
        """
        output_message = response["output"].get("message")
        content = ""
        if output_message:
            # Join only the text blocks; non-text blocks (if any) are skipped.
            content = "".join(
                block.get("text", "")
                for block in output_message.get("content", [])
            )
        return {
            "choices": [
                {
                    "message": {
                        "content": content,
                        "role": "assistant",
                    },
                }
            ]
        }

    def chat_completions_create(self, model, messages, **kwargs):
        """Call the Bedrock Converse API and return an OpenAI-style response.

        A leading system message is lifted out of ``messages`` and passed via
        the Converse ``system`` field. kwargs that Bedrock recognizes as
        inference parameters go into ``inferenceConfig``; everything else is
        forwarded as ``additionalModelRequestFields``.

        NOTE(review): assumes the remaining ``messages`` are already in
        Converse shape ({"role": ..., "content": [{"text": ...}]}) — confirm
        against callers.
        """
        # Any exception raised by boto3 propagates to the caller.
        # Maybe we should catch them and raise a custom LLMError.
        system_message = []  # Converse rejects None for `system`; default to empty list.
        if messages and messages[0]["role"] == "system":
            # Converse expects system content blocks, not bare strings.
            system_message = [{"text": messages[0]["content"]}]
            messages = messages[1:]

        # Split kwargs: fields Bedrock supports natively go in inferenceConfig,
        # everything else is passed as additionalModelRequestFields.
        inference_config = {}
        additional_model_request_fields = {}
        for key, value in kwargs.items():
            if key in self.inference_parameters:
                inference_config[key] = value
            else:
                additional_model_request_fields[key] = value

        # The Converse API takes the model identifier as `modelId`, not `model`.
        response = self.client.converse(
            modelId=model,
            messages=messages,
            system=system_message,
            inferenceConfig=inference_config,
            additionalModelRequestFields=additional_model_request_fields,
        )
        return self.normalize_response(response)
0 commit comments