diff --git a/aisuitealt/providers/anthropic_provider.py b/aisuitealt/providers/anthropic_provider.py
new file mode 100644
index 00000000..64fc82fe
--- /dev/null
+++ b/aisuitealt/providers/anthropic_provider.py
@@ -0,0 +1,40 @@
+import anthropic
+from provider import Provider
+
+class AnthropicProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the Anthropic provider with the given configuration.
+        Pass the entire configuration dictionary to the Anthropic client constructor.
+        """
+
+        self.client = anthropic.Anthropic(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Check if the first message is a system message.
+        if messages[0]["role"] == "system":
+            system_message = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_message = anthropic.NOT_GIVEN  # omit "system" rather than sending null
+
+        return self.normalize_response(self.client.messages.create(
+            model=model,
+            system=system_message,
+            messages=messages,
+            **kwargs
+        ))
+
+    def normalize_response(self, response):
+        """Normalize the response from the Anthropic API to match OpenAI's response format."""
+        # messages.create returns a Message object, not a dict, so use attribute access.
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "role": response.role,
+                        "content": response.content[0].text,
+                    }
+                }
+            ]
+        }
\ No newline at end of file
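Note: a minimal smoke test for the new provider, as a hedged sketch. The class and method names come from the diff above; the flat import path (matching `from provider import Provider`), the model id, and the reliance on an `ANTHROPIC_API_KEY` environment variable are assumptions for illustration, not part of the diff.

# Hypothetical usage sketch; not part of the diff above.
from anthropic_provider import AnthropicProvider

provider = AnthropicProvider()  # anthropic.Anthropic() reads ANTHROPIC_API_KEY by default
response = provider.chat_completions_create(
    model="claude-3-haiku-20240307",  # illustrative model id
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    max_tokens=100,  # forwarded via **kwargs; required by the Messages API
)
print(response["choices"][0]["message"]["content"])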
diff --git a/aisuitealt/providers/aws_bedrock_provider.py b/aisuitealt/providers/aws_bedrock_provider.py
index 1d7dd58d..cc2bc843 100644
--- a/aisuitealt/providers/aws_bedrock_provider.py
+++ b/aisuitealt/providers/aws_bedrock_provider.py
@@ -1,4 +1,4 @@
-from anthropic import AnthropicBedrock
+import boto3
 from provider import Provider, LLMError
 
 class AWSBedrockProvider(Provider):
@@ -10,13 +10,63 @@ def __init__(self, **config):
         # Anthropic Bedrock client will use the default AWS credential providers, such as
         # using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
         # Any overrides from the user is passed to the constructor.
-        self.client = AnthropicBedrock(**config)
+        self.client = boto3.client("bedrock-runtime", **config)
+        # Maintain a list of inference parameters which Bedrock supports.
+        # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
+        self.inference_parameters = ['maxTokens', 'temperature', 'topP', 'stopSequences']
+
+    def normalize_response(self, response):
+        """Normalize the response from the Bedrock API to match OpenAI's response format."""
+        # Converse returns content as a list of content blocks; take the first text block.
+        message = response["output"].get("message")
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "content": message["content"][0]["text"] if message else "",
+                        "role": "assistant"
+                    },
+                }
+            ]
+        }
 
     def chat_completions_create(self, model, messages, **kwargs):
         # Any exception raised by Anthropic will be returned to the caller.
         # Maybe we should catch them and raise a custom LLMError.
-        return self.client.messages.create(
-            model=model,
-            messages=messages,
-            **kwargs # Pass any additional arguments to the Anthropic API. Eg: max_tokens.
-        )
+        system_message = None
+        if messages[0]["role"] == "system":
+            # The Converse API expects system prompts as a list of content blocks.
+            system_message = [{"text": messages[0]["content"]}]
+            messages = messages[1:]
+
+        # The Converse API likewise expects message content as a list of content blocks.
+        formatted_messages = [
+            {"role": message["role"], "content": [{"text": message["content"]}]}
+            for message in messages
+        ]
+
+        # Separate kwargs into the inference parameters Bedrock supports, which
+        # must be passed via inferenceConfig, and everything else, which is
+        # passed via additionalModelRequestFields.
+        inference_config = {}
+        additional_model_request_fields = {}
+        for key, value in kwargs.items():
+            if key in self.inference_parameters:
+                inference_config[key] = value
+            else:
+                additional_model_request_fields[key] = value
+
+        # Build the Converse request. boto3 rejects None for optional parameters,
+        # so only include "system" when a system message was provided.
+        request = {
+            "modelId": model,
+            "messages": formatted_messages,
+            "inferenceConfig": inference_config,
+            "additionalModelRequestFields": additional_model_request_fields,
+        }
+        if system_message:
+            request["system"] = system_message
+
+        # Call the Bedrock Converse API.
+        response = self.client.converse(**request)
+        return self.normalize_response(response)
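Note: a similar hedged sketch for the Bedrock provider, showing how kwargs are split between inferenceConfig and additionalModelRequestFields. The region and model id are assumptions for illustration; credentials come from the standard boto3 providers mentioned in the diff.

# Hypothetical usage sketch; not part of the diff above.
from aws_bedrock_provider import AWSBedrockProvider

provider = AWSBedrockProvider(region_name="us-east-1")  # assumed region
response = provider.chat_completions_create(
    model="anthropic.claude-3-haiku-20240307-v1:0",  # illustrative Bedrock model id
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    maxTokens=100,    # in inference_parameters, so routed to inferenceConfig
    temperature=0.2,  # likewise routed to inferenceConfig
    top_k=40,         # unrecognized, so routed to additionalModelRequestFields
)
print(response["choices"][0]["message"]["content"])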