Skip to content

Commit 5236a96

Browse files
committed
Addressing review comments.
1 parent 510b881 commit 5236a96

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

aisuite/client.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@ def __init__(self, provider_configs: dict = {}):
66
"""
77
Initialize the client with provider configurations.
88
Use the ProviderFactory to create provider instances.
9+
10+
Args:
11+
provider_configs (dict): A dictionary containing provider configurations.
12+
Each key should be a ProviderNames enum or its string representation,
13+
and the value should be a dictionary of configuration options for that provider.
14+
For example:
15+
{
16+
ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
17+
"aws-bedrock": {
18+
"aws_access_key": "your_aws_access_key",
19+
"aws_secret_key": "your_aws_secret_key",
20+
"aws_region": "us-west-2"
21+
}
22+
}
923
"""
1024
self.providers = {}
1125
self.provider_configs = provider_configs
@@ -84,8 +98,11 @@ def create(self, model: str, messages: list, **kwargs):
8498
provider_key, model_name = model.split(":", 1)
8599

86100
if provider_key not in ProviderNames._value2member_map_:
101+
# If the provider key does not match, give a clearer message to guide the user
102+
valid_providers = ", ".join([p.value for p in ProviderNames])
87103
raise ValueError(
88-
f"Provider {provider_key} is not a valid ProviderNames enum"
104+
f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
105+
"Make sure the model string is formatted correctly as 'provider:model'."
89106
)
90107

91108
if provider_key not in self.client.providers:

aisuite/providers/aws_bedrock_provider.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,35 @@
33
from aisuite.framework import ChatCompletionResponse
44

55

6-
# Used to call the AWS Bedrock converse API
7-
# Converse API provides consistent API, that works with all Amazon Bedrock models that support messages.
8-
# Eg: anthropic.claude-v2,
9-
# meta.llama3-70b-instruct-v1:0,
10-
# mistral.mixtral-8x7b-instruct-v0:1
11-
# The model value can be a baseModelId or provisionedModelArn.
12-
# Using a base model id gives on-demand throughput.
13-
# Use CreateProvisionedModelThroughput API to get provisionedModelArn for higher throughput.
14-
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
156
class AWSBedrockProvider(Provider):
167
def __init__(self, **config):
178
"""
189
Initialize the AWS Bedrock provider with the given configuration.
19-
Pass the entire configuration dictionary to the Anthropic Bedrock client constructor.
10+
11+
This class uses the AWS Bedrock converse API, which provides a consistent interface
12+
for all Amazon Bedrock models that support messages. Examples include:
13+
- anthropic.claude-v2
14+
- meta.llama3-70b-instruct-v1:0
15+
- mistral.mixtral-8x7b-instruct-v0:1
16+
17+
The model value can be a baseModelId for on-demand throughput or a provisionedModelArn
18+
for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API.
19+
20+
For more information on model IDs, see:
21+
https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
22+
23+
Note:
24+
- The Anthropic Bedrock client uses default AWS credential providers, such as
25+
~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
26+
- If the region is not set, it defaults to us-west-1, which may lead to a
27+
"Could not connect to the endpoint URL" error.
28+
- The client constructor does not accept additional parameters.
29+
30+
Args:
31+
**config: Configuration options for the provider.
32+
2033
"""
21-
# Anthropic Bedrock client will use the default AWS credential providers, such as
22-
# using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
23-
# If region is not set, it will use a default to us-west-1 which can lead to error -
24-
# "Could not connect to the endpoint URL"
25-
# It does not like parameters passed to the constructor.
2634
self.client = boto3.client("bedrock-runtime", region_name="us-west-2")
27-
# Maintain a list of Inference Parameters which Bedrock supports.
28-
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
2935
self.inference_parameters = [
3036
"maxTokens",
3137
"temperature",

0 commit comments

Comments
 (0)