Skip to content

Commit e940359

Browse files
committed
Added .configure() for passing configs after construction instead of requiring them in the constructor.
1 parent 6b312eb commit e940359

File tree

4 files changed

+85
-40
lines changed

4 files changed

+85
-40
lines changed

aisuitealt/client.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33

44
class Client:
    def __init__(self, provider_configs: dict = None):
        """
        Initialize the client with provider configurations.
        Use the ProviderFactory to create provider instances.

        :param provider_configs: optional mapping of ProviderNames enum member
            -> provider-specific config dict. Defaults to no providers.
        :raises ValueError: if a key is not a ProviderNames enum member.
        """
        self.providers = {}
        # BUGFIX: the previous default was a shared mutable `{}`, and
        # configure() mutates self.provider_configs via .update() — so every
        # Client() instance would have shared (and polluted) the same dict.
        # Use a None sentinel and copy the caller's dict defensively.
        self.provider_configs = dict(provider_configs) if provider_configs else {}
        for provider_key, config in self.provider_configs.items():
            # Check if the provider key is a valid ProviderNames enum
            if not isinstance(provider_key, ProviderNames):
                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
            # NOTE(review): providers are keyed by the enum's string value,
            # mirroring configure() below — confirm against the factory's callers.
            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)

        self._chat = None

    def configure(self, provider_configs: dict = None):
        """
        Configure the client with provider configurations.

        Merges *provider_configs* into any configs given at construction time,
        then (re)creates a provider instance for every configured key.

        :param provider_configs: mapping of ProviderNames enum member ->
            config dict; a None value is a no-op.
        :raises ValueError: if a key is not a ProviderNames enum member.
        """
        if provider_configs is None:
            return

        self.provider_configs.update(provider_configs)

        # Rebuild all providers (not only the updated ones) so stale instances
        # never outlive a config change.
        for provider_key, config in self.provider_configs.items():
            if not isinstance(provider_key, ProviderNames):
                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
2035
@property
2136
def chat(self):
2237
"""Return the chat API interface."""
class Completions:
    def __init__(self, client: 'Client'):
        self.client = client

    def create(self, model: str, messages: list, **kwargs):
        """
        Create a chat completion for *model* with *messages*.

        :param model: "provider:model" identifier, e.g. "aws-bedrock:model-name".
        :param messages: chat messages passed through to the provider.
        :param kwargs: extra provider arguments (e.g. max_tokens, temperature).
        :raises ValueError: on a malformed model string, an unknown provider,
            or a provider that could not be loaded.
        """
        # Check that the correct "provider:model" format is used.
        if ':' not in model:
            raise ValueError(f"Invalid model format. Expected 'provider:model', got '{model}'")

        # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
        provider_key, model_name = model.split(":", 1)

        if provider_key not in ProviderNames._value2member_map_:
            raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")

        if provider_key not in self.client.providers:
            # BUGFIX: provider_configs is keyed by ProviderNames enum members
            # (see Client.__init__/configure), but provider_key here is a plain
            # string — the old string-only lookup found a config only if
            # ProviderNames is a str-enum. Accept both key forms.
            # TODO(review): confirm whether ProviderNames subclasses str.
            enum_key = ProviderNames(provider_key)
            config = self.client.provider_configs.get(
                provider_key,
                self.client.provider_configs.get(enum_key, {}),
            )
            self.client.providers[provider_key] = ProviderFactory.create_provider(enum_key, config)

        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Could not load provider for {provider_key}.")

        # Delegate the chat completion to the correct provider's implementation.
        # Any additional arguments will be passed to the provider's implementation.
        # Eg: max_tokens, temperature, etc.
        return provider.chat_completions_create(model_name, messages, **kwargs)
5585

5686

aisuitealt/providers/aws_bedrock_provider.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
from provider import Provider, LLMError
33

44
class AWSBedrockProvider(Provider):
    def __init__(self, **config):
        """
        Initialize the AWS Bedrock provider.

        The entire keyword configuration is forwarded verbatim to the
        AnthropicBedrock client constructor.
        """
        # With no explicit overrides, the Anthropic Bedrock client resolves
        # credentials through the default AWS providers — ~/.aws/credentials,
        # or the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment
        # variables. Any user-supplied overrides take precedence.
        self.client = AnthropicBedrock(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion via AWS Bedrock.

        Extra keyword arguments (e.g. max_tokens) are forwarded to the
        Anthropic API unchanged.
        """
        # Exceptions raised by Anthropic propagate to the caller as-is.
        # NOTE(review): we may want to wrap them in a custom LLMError instead.
        return self.client.messages.create(model=model, messages=messages, **kwargs)
Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
11
import openai
2+
import os
23
from provider import Provider, LLMError
34

45
class OpenAIProvider(Provider):
    def __init__(self, **config):
        """
        Initialize the OpenAI provider.

        The entire keyword configuration is forwarded to the OpenAI client
        constructor.
        """
        # An API key must come from the config or the OPENAI_API_KEY
        # environment variable (config wins; same as dict.setdefault).
        if 'api_key' not in config:
            config['api_key'] = os.getenv('OPENAI_API_KEY')
        if not config['api_key']:
            raise ValueError("OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable.")

        # NOTE: We could choose to remove the api_key handling above, since the
        # OpenAI client already infers certain values from environment
        # variables (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID,
        # OPENAI_BASE_URL, etc.).
        self.client = openai.OpenAI(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion via OpenAI.

        Extra keyword arguments are forwarded to the OpenAI API unchanged.
        """
        # Exceptions raised by OpenAI propagate to the caller as-is.
        # NOTE(review): we may want to wrap them in a custom LLMError instead.
        return self.client.chat.completions.create(model=model, messages=messages, **kwargs)

aisuitealt/tests/test_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ def test_client_chat_completions(self, mock_bedrock, mock_openai):
1818
'api_key': 'test_openai_api_key'
1919
},
2020
ProviderNames.AWS_BEDROCK: {
21-
'access_key': 'test_aws_access_key',
22-
'secret_key': 'test_aws_secret_key',
23-
'session_token': 'test_aws_session_token',
24-
'region': 'us-west-2'
21+
'aws_access_key': 'test_aws_access_key',
22+
'aws_secret_key': 'test_aws_secret_key',
23+
'aws_session_token': 'test_aws_session_token',
24+
'aws_region': 'us-west-2'
2525
}
2626
}
2727

2828
# Initialize the client
29-
client = Client(provider_configs)
29+
client = Client()
30+
client.configure(provider_configs)
3031

3132
# Test OpenAI model
3233
open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"

0 commit comments

Comments
 (0)