Skip to content

Commit 406e644

Browse files
rohitprasad15 and rohit-rptless
authored and committed
Refactoring the codebase.
This change is a WorkInProgress. An alternate way of writing the unified llm library. - ProviderFactory provides a specific provider impl. - Lazily import the provider based on config passed to Client. - Use Enums to restrict the supported providers, but not restricting the model names supported by the provider. TODO: - Support passing extra_parameters to model in create() - Run linter on the whole code for style conformance - Actual implementation & test case for each provider
1 parent 00105ae commit 406e644

File tree

9 files changed

+200
-0
lines changed

9 files changed

+200
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.vscode/
33
__pycache__/
4+
env/

aisuitealt/__init__.py

Whitespace-only changes.

aisuitealt/client.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from provider import ProviderFactory, ProviderNames
2+
3+
4+
class Client:
    """Unified entry point that routes chat requests to configured providers."""

    def __init__(self, provider_configs: dict):
        """
        Initialize the client with provider configurations.

        Each key must be a ProviderNames enum member; each value is the
        config dict handed to the ProviderFactory for that provider.
        """
        self.providers = {}
        for provider_key, config in provider_configs.items():
            # Reject anything that is not a member of the supported-provider enum.
            if not isinstance(provider_key, ProviderNames):
                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
            # Key by the enum's string value so "provider:model" identifiers
            # can be resolved later without touching the enum.
            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)

        self._chat = None

    @property
    def chat(self):
        """Return the chat API interface, creating it lazily on first access."""
        if not self._chat:
            self._chat = Chat(self)
        return self._chat
class Chat:
    """Namespace object exposing the chat-related APIs of a Client."""

    def __init__(self, client: 'Client'):
        self.client = client

    @property
    def completions(self):
        """Return a completions interface bound to this chat's client."""
        return Completions(self.client)
class Completions:
    """Chat-completions interface that dispatches on "provider:model" strings."""

    def __init__(self, client: 'Client'):
        self.client = client

    def create(self, model: str, messages: list):
        """Create a chat completion via the provider named in ``model``.

        Args:
            model: Identifier of the form "<provider>:<model-name>",
                e.g. "aws-bedrock:claude-v3".
            messages: Chat messages in the provider's expected format.

        Returns:
            Whatever the provider's ``chat_completions_create`` returns.

        Raises:
            ValueError: If ``model`` lacks the ":" separator or names a
                provider that was not configured on the client.
        """
        # Fail fast with a clear message instead of an unpacking error.
        if ":" not in model:
            raise ValueError(
                f"Invalid model identifier {model!r}; expected '<provider>:<model-name>'"
            )
        # Extract the provider key, e.g. "aws-bedrock:model-name".
        provider_key, model_name = model.split(":", 1)

        # Use the provider instance created by the factory.
        provider = self.client.providers.get(provider_key)
        if not provider:
            # List configured provider keys (not the instances) in the error.
            raise ValueError(
                f"Provider {provider_key} is not present in the client. "
                f"Configured providers: {list(self.client.providers)}"
            )

        # Delegate the chat completion to the provider's implementation.
        return provider.chat_completions_create(model_name, messages)

aisuitealt/provider.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
import importlib
4+
5+
class LLMError(Exception):
    """Custom exception raised for errors coming from LLM providers."""

    def __init__(self, message):
        super().__init__(message)
class Provider(ABC):
    """Interface that every concrete LLM provider must implement."""

    @abstractmethod
    def chat_completions_create(self, model, messages):
        """Abstract method for chat completion calls, to be implemented by each provider."""
class ProviderNames(Enum):
    """Closed set of provider identifiers the factory knows how to build."""

    OPENAI = 'openai'
    AWS_BEDROCK = 'aws-bedrock'
class ProviderFactory:
    """Factory to register and create provider instances based on keys."""

    # Module path per supported provider; modules are imported lazily so a
    # provider's SDK is only required when that provider is configured.
    _provider_modules = {
        ProviderNames.OPENAI: 'providers.openai_provider',
        ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider',
    }

    @classmethod
    def create_provider(cls, provider_key, config):
        """Dynamically import and create an instance of a provider.

        Args:
            provider_key: A ProviderNames member selecting the provider.
            config: Dict of keyword arguments forwarded to the provider's
                constructor.

        Returns:
            A new instance of the provider's class.

        Raises:
            ValueError: If ``provider_key`` is not a ProviderNames member or
                has no registered module.
            ImportError: If the provider module cannot be imported; chained
                to the original import failure.
        """
        if not isinstance(provider_key, ProviderNames):
            raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")

        module_name = cls._provider_modules.get(provider_key)
        if not module_name:
            raise ValueError(f"Provider {provider_key.value} is not supported")

        # Lazily load the module.
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            # Chain the cause so the root import failure stays visible.
            raise ImportError(f"Could not import module {module_name}: {str(e)}") from e

        # Instantiate the provider class with the user-supplied config.
        provider_class = getattr(module, cls._get_provider_class_name(provider_key))
        return provider_class(**config)

    @staticmethod
    def _get_provider_class_name(provider_key):
        """Map provider key to the corresponding class name in the module."""
        return {
            ProviderNames.OPENAI: 'OpenAIProvider',
            ProviderNames.AWS_BEDROCK: 'AWSBedrockProvider',
        }[provider_key]

aisuitealt/providers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty __init__.py
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from anthropic import AnthropicBedrock
2+
from provider import Provider, LLMError
3+
4+
class AWSBedrockProvider(Provider):
    """Provider backed by Anthropic models served through AWS Bedrock."""

    def __init__(self, access_key, secret_key, session_token, region):
        """Create an AnthropicBedrock client from explicit AWS credentials."""
        self.client = AnthropicBedrock(
            aws_access_key=access_key,
            aws_secret_key=secret_key,
            aws_session_token=session_token,
            aws_region=region
        )

    def chat_completions_create(self, model, messages):
        """Return the text of the model's reply for ``messages``.

        Raises:
            LLMError: Wrapping any failure from the underlying API call.
        """
        try:
            response = self.client.messages.create(
                model=model,
                max_tokens=256,
                messages=messages
            )
            # The Anthropic Messages API returns a Message object whose
            # ``content`` is a list of content blocks — not an OpenAI-style
            # dict, so the previous ['choices'][0]['message']['content']
            # access always raised and was swallowed into LLMError.
            return response.content[0].text
        except Exception as e:
            # Chain the cause so the original API error stays visible.
            raise LLMError(f"AWS Bedrock API error: {str(e)}") from e
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import openai
2+
from provider import Provider, LLMError
3+
4+
class OpenAIProvider(Provider):
5+
def __init__(self, api_key):
6+
openai.api_key = api_key
7+
8+
def chat_completions_create(self, model, messages):
9+
try:
10+
response = openai.ChatCompletion.create(
11+
model=model,
12+
messages=messages
13+
)
14+
return response['choices'][0]['message']['content']
15+
except Exception as e:
16+
raise LLMError(f"OpenAI API error: {str(e)}")

aisuitealt/tests/__init__.py

Whitespace-only changes.

aisuitealt/tests/test_client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
from unittest.mock import patch
3+
from client import Client
4+
from provider import ProviderNames
5+
6+
class TestClient(unittest.TestCase):
    """End-to-end routing tests for Client with mocked provider calls."""

    @patch('providers.openai_provider.OpenAIProvider.chat_completions_create')
    @patch('providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create')
    def test_client_chat_completions(self, mock_bedrock, mock_openai):
        """Client routes 'provider:model' identifiers to the right provider."""
        # Canned responses so no real network calls are made.
        mock_openai.return_value = "OpenAI Response"
        mock_bedrock.return_value = "AWS Bedrock Response"

        provider_configs = {
            ProviderNames.OPENAI: {'api_key': 'test_openai_api_key'},
            ProviderNames.AWS_BEDROCK: {
                'access_key': 'test_aws_access_key',
                'secret_key': 'test_aws_secret_key',
                'session_token': 'test_aws_session_token',
                'region': 'us-west-2',
            },
        }
        client = Client(provider_configs)

        # OpenAI routing.
        openai_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]
        openai_response = client.chat.completions.create(
            f"{ProviderNames.OPENAI.value}:gpt-4o", openai_messages
        )
        self.assertEqual(openai_response, "OpenAI Response")
        mock_openai.assert_called_once()

        # AWS Bedrock routing.
        bedrock_response = client.chat.completions.create(
            f"{ProviderNames.AWS_BEDROCK.value}:claude-v3",
            [{"role": "user", "content": "Hello, world!"}],
        )
        self.assertEqual(bedrock_response, "AWS Bedrock Response")
        mock_bedrock.assert_called_once()
46+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)