Support for Bedrock and Anthropic
Also adds code to normalize responses to the OpenAI format.
rohitprasad15 authored and rohit-rptless committed Sep 5, 2024
1 parent e940359 commit 9c2680e
Showing 2 changed files with 81 additions and 4 deletions.
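
Both providers normalize their native responses to the shape of an OpenAI chat completion. A minimal sketch of the target structure this commit produces (a real OpenAI response also carries fields such as id, model, and usage, which are not reproduced here):

# The normalized, OpenAI-style response shape returned by both providers.
normalized = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "Hello! How can I help?",
            }
        }
    ]
}
print(normalized["choices"][0]["message"]["content"])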
39 changes: 39 additions & 0 deletions aisuitealt/providers/anthropic_provider.py
@@ -0,0 +1,39 @@
import anthropic
from provider import Provider

class AnthropicProvider(Provider):
    def __init__(self, **config):
        """
        Initialize the Anthropic provider with the given configuration.
        The entire configuration dictionary is passed to the Anthropic client constructor.
        """
        self.client = anthropic.Anthropic(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        # Check if the first message is a system message; Anthropic expects it
        # as a separate `system` parameter rather than inside the messages list.
        if messages[0]["role"] == "system":
            # Only pass `system` when present; an explicit None is not a valid value.
            kwargs["system"] = messages[0]["content"]
            messages = messages[1:]

        return self.normalize_response(self.client.messages.create(
            model=model,
            messages=messages,
            **kwargs  # Additional arguments, e.g. the required max_tokens.
        ))

    def normalize_response(self, response):
        """Normalize the response from the Anthropic API to match OpenAI's response format."""
        # messages.create returns an anthropic.types.Message object rather than
        # a dict: role is an attribute, and content is a list of content blocks.
        return {
            "choices": [
                {
                    "message": {
                        "role": response.role,
                        "content": response.content[0].text if response.content else "",
                    }
                }
            ]
        }
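
A minimal usage sketch for this provider (the model name is illustrative; assumes ANTHROPIC_API_KEY is set in the environment, and note that max_tokens, which Anthropic's Messages API requires, is passed through kwargs):

provider = AnthropicProvider()
response = provider.chat_completions_create(
    model="claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Name one AWS region."},
    ],
    max_tokens=100,  # required by Anthropic's Messages API
)
print(response["choices"][0]["message"]["content"])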
46 changes: 42 additions & 4 deletions aisuitealt/providers/aws_bedrock_provider.py
@@ -1,4 +1,4 @@
-from anthropic import AnthropicBedrock
import boto3
from provider import Provider, LLMError

class AWSBedrockProvider(Provider):
@@ -10,13 +10,51 @@ def __init__(self, **config):
        # The Bedrock runtime client will use the default AWS credential providers, such as
        # ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
        # Any overrides from the user are passed to the constructor.
-       self.client = AnthropicBedrock(**config)
        self.client = boto3.client("bedrock-runtime", **config)
        # Maintain a list of inference parameters which Bedrock supports.
        # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
        self.inference_parameters = ['maxTokens', 'temperature', 'topP', 'stopSequences']

    def normalize_response(self, response):
        """Normalize the response from the Bedrock Converse API to match OpenAI's response format."""
        # Converse returns output.message.content as a list of content blocks,
        # e.g. [{"text": "..."}]; extract the text of the first block.
        message = response["output"].get("message")
        return {
            "choices": [
                {
                    "message": {
                        "content": message["content"][0]["text"] if message else "",
                        "role": "assistant"
                    },
                }
            ]
        }

    def chat_completions_create(self, model, messages, **kwargs):
-       # Any exception raised by Anthropic will be returned to the caller.
-       # Maybe we should catch them and raise a custom LLMError.
-       return self.client.messages.create(
-           model=model,
-           messages=messages,
-           **kwargs # Pass any additional arguments to the Anthropic API. Eg: max_tokens.
-       )
        system_message = None
        if messages[0]["role"] == "system":
            # The Converse API takes the system prompt as a list of content blocks.
            system_message = [{"text": messages[0]["content"]}]
            messages = messages[1:]

        # Bedrock supports only a fixed set of inference parameters, which must
        # be passed via inferenceConfig; everything else is passed through
        # additionalModelRequestFields.
        inference_config = {}
        additional_model_request_fields = {}

        # Separate the inference parameters from the additional model request fields.
        for key, value in kwargs.items():
            if key in self.inference_parameters:
                inference_config[key] = value
            else:
                additional_model_request_fields[key] = value

        # Call the Bedrock Converse API. boto3 rejects None for optional
        # parameters, so `system` is included only when one was supplied.
        request = {
            "modelId": model,  # Converse expects the model identifier as modelId
            "messages": messages,
            "inferenceConfig": inference_config,
            "additionalModelRequestFields": additional_model_request_fields,
        }
        if system_message is not None:
            request["system"] = system_message
        response = self.client.converse(**request)
        return self.normalize_response(response)
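
Likewise, a minimal usage sketch for the Bedrock provider (illustrative model ID and region; assumes AWS credentials are configured locally). It shows how kwargs are split: maxTokens is in the supported list and lands in inferenceConfig, while a model-specific field such as top_k falls through to additionalModelRequestFields:

provider = AWSBedrockProvider(region_name="us-east-1")
response = provider.chat_completions_create(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative model ID
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        # Converse expects user/assistant content as a list of content blocks.
        {"role": "user", "content": [{"text": "Name one AWS region."}]},
    ],
    maxTokens=100,  # supported inference parameter -> inferenceConfig
    top_k=40,       # model-specific field -> additionalModelRequestFields
)
print(response["choices"][0]["message"]["content"])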
