From 5236a9674bd9900369f3141bfd873da70e86e43f Mon Sep 17 00:00:00 2001 From: rohit-rptless Date: Thu, 12 Sep 2024 11:47:03 -0700 Subject: [PATCH] Addressing review comments. --- aisuite/client.py | 19 ++++++++++- aisuite/providers/aws_bedrock_provider.py | 40 +++++++++++++---------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/aisuite/client.py b/aisuite/client.py index f6474096..326a83c0 100644 --- a/aisuite/client.py +++ b/aisuite/client.py @@ -6,6 +6,20 @@ def __init__(self, provider_configs: dict = {}): """ Initialize the client with provider configurations. Use the ProviderFactory to create provider instances. + + Args: + provider_configs (dict): A dictionary containing provider configurations. + Each key should be a ProviderNames enum or its string representation, + and the value should be a dictionary of configuration options for that provider. + For example: + { + ProviderNames.OPENAI: {"api_key": "your_openai_api_key"}, + "aws-bedrock": { + "aws_access_key": "your_aws_access_key", + "aws_secret_key": "your_aws_secret_key", + "aws_region": "us-west-2" + } + } """ self.providers = {} self.provider_configs = provider_configs @@ -84,8 +98,11 @@ def create(self, model: str, messages: list, **kwargs): provider_key, model_name = model.split(":", 1) if provider_key not in ProviderNames._value2member_map_: + # If the provider key does not match, give a clearer message to guide the user + valid_providers = ", ".join([p.value for p in ProviderNames]) raise ValueError( - f"Provider {provider_key} is not a valid ProviderNames enum" + f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. " + "Make sure the model string is formatted correctly as 'provider:model'." ) if provider_key not in self.client.providers: diff --git a/aisuite/providers/aws_bedrock_provider.py b/aisuite/providers/aws_bedrock_provider.py index f9cc6979..c6991c82 100644 --- a/aisuite/providers/aws_bedrock_provider.py +++ b/aisuite/providers/aws_bedrock_provider.py @@ -3,29 +3,35 @@ from aisuite.framework import ChatCompletionResponse -# Used to call the AWS Bedrock converse API -# Converse API provides consistent API, that works with all Amazon Bedrock models that support messages. -# Eg: anthropic.claude-v2, -# meta.llama3-70b-instruct-v1:0, -# mistral.mixtral-8x7b-instruct-v0:1 -# The model value can be a baseModelId or provisionedModelArn. -# Using a base model id gives on-demand throughput. -# Use CreateProvisionedModelThroughput API to get provisionedModelArn for higher throughput. -# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html class AWSBedrockProvider(Provider): def __init__(self, **config): """ Initialize the AWS Bedrock provider with the given configuration. - Pass the entire configuration dictionary to the Anthropic Bedrock client constructor. + + This class uses the AWS Bedrock converse API, which provides a consistent interface + for all Amazon Bedrock models that support messages. Examples include: + - anthropic.claude-v2 + - meta.llama3-70b-instruct-v1:0 + - mistral.mixtral-8x7b-instruct-v0:1 + + The model value can be a baseModelId for on-demand throughput or a provisionedModelArn + for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API. + + For more information on model IDs, see: + https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html + + Note: + - The Anthropic Bedrock client uses default AWS credential providers, such as + ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables. + - If the region is not set, it defaults to us-west-1, which may lead to a + "Could not connect to the endpoint URL" error. + - The client constructor does not accept additional parameters. + + Args: + **config: Configuration options for the provider. + """ - # Anthropic Bedrock client will use the default AWS credential providers, such as - # using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables. - # If region is not set, it will use a default to us-west-1 which can lead to error - - # "Could not connect to the endpoint URL" - # It does not like parameters passed to the constructor. self.client = boto3.client("bedrock-runtime", region_name="us-west-2") - # Maintain a list of Inference Parameters which Bedrock supports. - # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html self.inference_parameters = [ "maxTokens", "temperature",