Skip to content

Commit

Permalink
Addressing review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohit-rptless committed Sep 12, 2024
1 parent 510b881 commit 5236a96
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
19 changes: 18 additions & 1 deletion aisuite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ def __init__(self, provider_configs: dict = {}):
"""
Initialize the client with provider configurations.
Use the ProviderFactory to create provider instances.
Args:
provider_configs (dict): A dictionary containing provider configurations.
Each key should be a ProviderNames enum or its string representation,
and the value should be a dictionary of configuration options for that provider.
For example:
{
ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
"aws-bedrock": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
"aws_region": "us-west-2"
}
}
"""
self.providers = {}
self.provider_configs = provider_configs
Expand Down Expand Up @@ -84,8 +98,11 @@ def create(self, model: str, messages: list, **kwargs):
provider_key, model_name = model.split(":", 1)

if provider_key not in ProviderNames._value2member_map_:
# If the provider key does not match, give a clearer message to guide the user
valid_providers = ", ".join([p.value for p in ProviderNames])
raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum"
f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

if provider_key not in self.client.providers:
Expand Down
40 changes: 23 additions & 17 deletions aisuite/providers/aws_bedrock_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,35 @@
from aisuite.framework import ChatCompletionResponse


# Used to call the AWS Bedrock converse API
# Converse API provides consistent API, that works with all Amazon Bedrock models that support messages.
# Eg: anthropic.claude-v2,
# meta.llama3-70b-instruct-v1:0,
# mistral.mixtral-8x7b-instruct-v0:1
# The model value can be a baseModelId or provisionedModelArn.
# Using a base model id gives on-demand throughput.
# Use CreateProvisionedModelThroughput API to get provisionedModelArn for higher throughput.
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
class AWSBedrockProvider(Provider):
def __init__(self, **config):
"""
Initialize the AWS Bedrock provider with the given configuration.
Pass the entire configuration dictionary to the Anthropic Bedrock client constructor.
This class uses the AWS Bedrock converse API, which provides a consistent interface
for all Amazon Bedrock models that support messages. Examples include:
- anthropic.claude-v2
- meta.llama3-70b-instruct-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
The model value can be a baseModelId for on-demand throughput or a provisionedModelArn
for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API.
For more information on model IDs, see:
https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
Note:
- The Anthropic Bedrock client uses default AWS credential providers, such as
~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
- If the region is not set, it defaults to us-west-1, which may lead to a
"Could not connect to the endpoint URL" error.
- The client constructor does not accept additional parameters.
Args:
**config: Configuration options for the provider.
"""
# Anthropic Bedrock client will use the default AWS credential providers, such as
# using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
# If region is not set, it will use a default to us-west-1 which can lead to error -
# "Could not connect to the endpoint URL"
# It does not like parameters passed to the constructor.
self.client = boto3.client("bedrock-runtime", region_name="us-west-2")
# Maintain a list of Inference Parameters which Bedrock supports.
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
self.inference_parameters = [
"maxTokens",
"temperature",
Expand Down

0 comments on commit 5236a96

Please sign in to comment.