Skip to content

Commit

Permalink
Refactoring the codebase.
Browse files Browse the repository at this point in the history
This change is a work in progress.

An alternate way of writing the unified llm library.
- ProviderFactory provides a specific provider impl.
- Lazily import the provider based on config passed to Client.
- Use Enums to restrict the supported providers, without
  restricting the model names supported by each provider.
TODO:
- Support passing extra_parameters to model in create()
- Run linter on the whole code for style conformance
- Actual implementation & test case for each provider
  • Loading branch information
rohitprasad15 authored and rohit-rptless committed Sep 5, 2024
1 parent 00105ae commit 406e644
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
.vscode/
__pycache__/
env/
Empty file added aisuitealt/__init__.py
Empty file.
56 changes: 56 additions & 0 deletions aisuitealt/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from provider import ProviderFactory, ProviderNames


class Client:
    def __init__(self, provider_configs: dict):
        """
        Build a client from a mapping of ProviderNames member -> config dict.

        Each entry is validated and turned into a concrete provider instance
        via the ProviderFactory.
        """
        self.providers = {}
        for name, settings in provider_configs.items():
            # Only ProviderNames enum members are accepted as keys.
            if not isinstance(name, ProviderNames):
                raise ValueError(f"Provider {name} is not a valid ProviderNames enum")
            # Key the registry by the enum's string value so that model
            # identifiers like "openai:gpt-4o" can look providers up directly.
            self.providers[name.value] = ProviderFactory.create_provider(name, settings)

        self._chat = None

    @property
    def chat(self):
        """Lazily construct and return the chat API interface."""
        if self._chat is None:
            self._chat = Chat(self)
        return self._chat


class Chat:
    """Namespace object exposing the chat-related endpoints of a Client."""

    def __init__(self, client: 'Client'):
        # Keep a reference back to the owning client for provider lookups.
        self.client = client

    @property
    def completions(self):
        """Build a Completions interface bound to the owning client."""
        return Completions(self.client)


class Completions:
    """Chat-completions endpoint; routes each request to the right provider."""

    def __init__(self, client: 'Client'):
        self.client = client

    def create(self, model: str, messages: list):
        """
        Create a chat completion.

        :param model: "<provider>:<model>" identifier, e.g. "aws-bedrock:model-name".
        :param messages: list of chat messages (role/content dicts).
        :raises ValueError: if the model string has no ":" separator or the
            provider is not configured on the client.
        """
        # Robustness: fail with a clear message when the separator is missing,
        # instead of an opaque "not enough values to unpack" ValueError.
        if ":" not in model:
            raise ValueError(
                f"Invalid model identifier {model!r}; expected '<provider>:<model>'"
            )
        # Split only on the first ":" so model names may themselves contain ":".
        provider_key, model_name = model.split(":", 1)

        # Use the provider instance created by the factory at client init.
        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Provider {provider_key} is not present in the client. Here are the providers: {self.client.providers}")

        # Delegate the chat completion to the provider's implementation.
        return provider.chat_completions_create(model_name, messages)


57 changes: 57 additions & 0 deletions aisuitealt/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from enum import Enum
import importlib

class LLMError(Exception):
    """Base exception for errors raised by provider implementations."""
    # The original defined an __init__ that only forwarded the message to
    # Exception; Exception already stores its args, so no override is needed.


class Provider(ABC):
    """Interface that every concrete LLM provider must implement."""

    @abstractmethod
    def chat_completions_create(self, model, messages):
        """Issue a chat-completion call; implemented by each provider."""


class ProviderNames(Enum):
    """Closed set of provider identifiers understood by the factory."""

    OPENAI = 'openai'
    AWS_BEDROCK = 'aws-bedrock'


class ProviderFactory:
    """Factory that lazily imports and instantiates provider implementations."""

    # Map each provider enum member to the module hosting its implementation.
    _provider_modules = {
        ProviderNames.OPENAI: 'providers.openai_provider',
        ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider',
    }

    @classmethod
    def create_provider(cls, provider_key, config):
        """
        Import the module for *provider_key* and build its provider instance.

        :param provider_key: a ProviderNames member selecting the implementation.
        :param config: dict of keyword arguments forwarded to the provider's
            constructor.
        :raises ValueError: if the key is not a ProviderNames member, or has
            no registered module.
        :raises ImportError: if the provider module cannot be imported.
        """
        if not isinstance(provider_key, ProviderNames):
            raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")

        module_name = cls._provider_modules.get(provider_key)
        if not module_name:
            raise ValueError(f"Provider {provider_key.value} is not supported")

        # Lazy import keeps optional provider SDKs off the import path until
        # the provider is actually configured.
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            # Chain the original error so the root cause stays visible
            # (the original re-raise dropped the explicit cause).
            raise ImportError(f"Could not import module {module_name}: {str(e)}") from e

        # Instantiate the provider class, passing config as keyword args.
        provider_class = getattr(module, cls._get_provider_class_name(provider_key))
        return provider_class(**config)

    @staticmethod
    def _get_provider_class_name(provider_key):
        """Map a provider key to its implementation class name in the module."""
        return {
            ProviderNames.OPENAI: 'OpenAIProvider',
            ProviderNames.AWS_BEDROCK: 'AWSBedrockProvider',
        }[provider_key]
1 change: 1 addition & 0 deletions aisuitealt/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Empty __init__.py
22 changes: 22 additions & 0 deletions aisuitealt/providers/aws_bedrock_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from anthropic import AnthropicBedrock
from provider import Provider, LLMError

class AWSBedrockProvider(Provider):
    """Provider backed by Anthropic models served through AWS Bedrock."""

    def __init__(self, access_key, secret_key, session_token, region):
        """Configure the AnthropicBedrock client with explicit AWS credentials."""
        self.client = AnthropicBedrock(
            aws_access_key=access_key,
            aws_secret_key=secret_key,
            aws_session_token=session_token,
            aws_region=region
        )

    def chat_completions_create(self, model, messages):
        """Run a chat completion and return the assistant's text reply.

        :raises LLMError: wrapping any error raised by the Bedrock call.
        """
        try:
            response = self.client.messages.create(
                model=model,
                max_tokens=256,
                messages=messages
            )
            # Bug fix: the Anthropic SDK returns a Message object whose reply
            # lives in .content (a list of content blocks), not an
            # OpenAI-style dict with a 'choices' key — the original indexing
            # always raised and was swallowed into LLMError.
            return response.content[0].text
        except Exception as e:
            raise LLMError(f"AWS Bedrock API error: {str(e)}") from e
16 changes: 16 additions & 0 deletions aisuitealt/providers/openai_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import openai
from provider import Provider, LLMError

class OpenAIProvider(Provider):
    """Provider that talks to the OpenAI chat-completion API."""

    def __init__(self, api_key):
        # NOTE(review): this sets the key on the global openai module (legacy
        # pre-1.0 SDK style), so all OpenAIProvider instances share one key.
        openai.api_key = api_key

    def chat_completions_create(self, model, messages):
        """Run a chat completion and return the assistant's text reply.

        :raises LLMError: wrapping any error raised by the OpenAI call.
        """
        try:
            result = openai.ChatCompletion.create(
                model=model,
                messages=messages,
            )
            return result['choices'][0]['message']['content']
        except Exception as e:
            raise LLMError(f"OpenAI API error: {str(e)}")
Empty file added aisuitealt/tests/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions aisuitealt/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import unittest
from unittest.mock import patch
from client import Client
from provider import ProviderNames

class TestClient(unittest.TestCase):
    """Routing test: model strings reach the matching mocked provider."""

    @patch('providers.openai_provider.OpenAIProvider.chat_completions_create')
    @patch('providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create')
    def test_client_chat_completions(self, mock_bedrock, mock_openai):
        # patch decorators apply bottom-up, so the bedrock mock arrives as the
        # first argument and the openai mock as the second.
        mock_openai.return_value = "OpenAI Response"
        mock_bedrock.return_value = "AWS Bedrock Response"

        configs = {
            ProviderNames.OPENAI: {
                'api_key': 'test_openai_api_key'
            },
            ProviderNames.AWS_BEDROCK: {
                'access_key': 'test_aws_access_key',
                'secret_key': 'test_aws_secret_key',
                'session_token': 'test_aws_session_token',
                'region': 'us-west-2',
            },
        }

        client = Client(configs)

        # The "openai:..." model string must route to the OpenAI provider.
        response = client.chat.completions.create(
            f"{ProviderNames.OPENAI.value}:gpt-4o",
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Who won the world series in 2020?"},
            ],
        )
        self.assertEqual(response, "OpenAI Response")
        mock_openai.assert_called_once()

        # The "aws-bedrock:..." model string must route to the Bedrock provider.
        response = client.chat.completions.create(
            f"{ProviderNames.AWS_BEDROCK.value}:claude-v3",
            [{"role": "user", "content": "Hello, world!"}],
        )
        self.assertEqual(response, "AWS Bedrock Response")
        mock_bedrock.assert_called_once()


if __name__ == '__main__':
    unittest.main()

0 comments on commit 406e644

Please sign in to comment.