-
Notifications
You must be signed in to change notification settings - Fork 901
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change is a WorkInProgress. An alternate way of writing the unified llm library. - ProviderFactory provides a specific provider impl. - Lazily import the provider based on config passed to Client. - Use Enums to restrict the supported providers, but not restricting the model names supported by the provider. TODO: - Support passing extra_parameters to model in create() - Run linter on the whole code for style conformance - Actual implementation & test case for each provider
- Loading branch information
1 parent
00105ae
commit 406e644
Showing
9 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
.idea/ | ||
.vscode/ | ||
__pycache__/ | ||
env/ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from provider import ProviderFactory, ProviderNames | ||
|
||
|
||
class Client: | ||
def __init__(self, provider_configs: dict): | ||
""" | ||
Initialize the client with provider configurations. | ||
Use the ProviderFactory to create provider instances. | ||
""" | ||
self.providers = {} | ||
for provider_key, config in provider_configs.items(): | ||
# Check if the provider key is a valid ProviderNames enum | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
# Store the value of the enum in the providers dictionary | ||
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config) | ||
|
||
self._chat = None | ||
|
||
@property | ||
def chat(self): | ||
"""Return the chat API interface.""" | ||
if not self._chat: | ||
self._chat = Chat(self) | ||
return self._chat | ||
|
||
|
||
class Chat: | ||
def __init__(self, client: 'Client'): | ||
self.client = client | ||
|
||
@property | ||
def completions(self): | ||
"""Return the completions interface.""" | ||
return Completions(self.client) | ||
|
||
|
||
class Completions: | ||
def __init__(self, client: 'Client'): | ||
self.client = client | ||
|
||
def create(self, model: str, messages: list): | ||
"""Create chat completion based on the model.""" | ||
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name" | ||
provider_key, model_name = model.split(":", 1) | ||
|
||
# Use the correct provider instance created by the factory | ||
provider = self.client.providers.get(provider_key) | ||
if not provider: | ||
# Add the providers to the ValueError | ||
raise ValueError(f"Provider {provider_key} is not present in the client. Here are the providers: {self.client.providers}") | ||
|
||
# Delegate the chat completion to the correct provider's implementation | ||
return provider.chat_completions_create(model_name, messages) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
import importlib | ||
|
||
class LLMError(Exception): | ||
"""Custom exception for LLM errors.""" | ||
def __init__(self, message): | ||
super().__init__(message) | ||
|
||
|
||
class Provider(ABC): | ||
@abstractmethod | ||
def chat_completions_create(self, model, messages): | ||
"""Abstract method for chat completion calls, to be implemented by each provider.""" | ||
pass | ||
|
||
|
||
class ProviderNames(Enum): | ||
OPENAI = 'openai' | ||
AWS_BEDROCK = 'aws-bedrock' | ||
|
||
|
||
class ProviderFactory: | ||
"""Factory to register and create provider instances based on keys.""" | ||
|
||
_provider_modules = { | ||
ProviderNames.OPENAI: 'providers.openai_provider', | ||
ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider', | ||
} | ||
|
||
@classmethod | ||
def create_provider(cls, provider_key, config): | ||
"""Dynamically import and create an instance of a provider based on the provider key.""" | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
|
||
module_name = cls._provider_modules.get(provider_key) | ||
if not module_name: | ||
raise ValueError(f"Provider {provider_key.value} is not supported") | ||
|
||
# Lazily load the module | ||
try: | ||
module = importlib.import_module(module_name) | ||
except ImportError as e: | ||
raise ImportError(f"Could not import module {module_name}: {str(e)}") | ||
|
||
# Instantiate the provider class | ||
provider_class = getattr(module, cls._get_provider_class_name(provider_key)) | ||
return provider_class(**config) | ||
|
||
@staticmethod | ||
def _get_provider_class_name(provider_key): | ||
"""Map provider key to the corresponding class name in the module.""" | ||
return { | ||
ProviderNames.OPENAI: 'OpenAIProvider', | ||
ProviderNames.AWS_BEDROCK: 'AWSBedrockProvider', | ||
}[provider_key] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Empty __init__.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from anthropic import AnthropicBedrock | ||
from provider import Provider, LLMError | ||
|
||
class AWSBedrockProvider(Provider): | ||
def __init__(self, access_key, secret_key, session_token, region): | ||
self.client = AnthropicBedrock( | ||
aws_access_key=access_key, | ||
aws_secret_key=secret_key, | ||
aws_session_token=session_token, | ||
aws_region=region | ||
) | ||
|
||
def chat_completions_create(self, model, messages): | ||
try: | ||
response = self.client.messages.create( | ||
model=model, | ||
max_tokens=256, | ||
messages=messages | ||
) | ||
return response['choices'][0]['message']['content'] | ||
except Exception as e: | ||
raise LLMError(f"AWS Bedrock API error: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import openai | ||
from provider import Provider, LLMError | ||
|
||
class OpenAIProvider(Provider): | ||
def __init__(self, api_key): | ||
openai.api_key = api_key | ||
|
||
def chat_completions_create(self, model, messages): | ||
try: | ||
response = openai.ChatCompletion.create( | ||
model=model, | ||
messages=messages | ||
) | ||
return response['choices'][0]['message']['content'] | ||
except Exception as e: | ||
raise LLMError(f"OpenAI API error: {str(e)}") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
from client import Client | ||
from provider import ProviderNames | ||
|
||
class TestClient(unittest.TestCase): | ||
|
||
@patch('providers.openai_provider.OpenAIProvider.chat_completions_create') | ||
@patch('providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create') | ||
def test_client_chat_completions(self, mock_bedrock, mock_openai): | ||
# Mock responses from providers | ||
mock_openai.return_value = "OpenAI Response" | ||
mock_bedrock.return_value = "AWS Bedrock Response" | ||
|
||
# Provider configurations | ||
provider_configs = { | ||
ProviderNames.OPENAI: { | ||
'api_key': 'test_openai_api_key' | ||
}, | ||
ProviderNames.AWS_BEDROCK: { | ||
'access_key': 'test_aws_access_key', | ||
'secret_key': 'test_aws_secret_key', | ||
'session_token': 'test_aws_session_token', | ||
'region': 'us-west-2' | ||
} | ||
} | ||
|
||
# Initialize the client | ||
client = Client(provider_configs) | ||
|
||
# Test OpenAI model | ||
openai_response = client.chat.completions.create(ProviderNames.OPENAI.value + ":" + "gpt-4o", [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "Who won the world series in 2020?"} | ||
]) | ||
self.assertEqual(openai_response, "OpenAI Response") | ||
mock_openai.assert_called_once() | ||
|
||
# Test AWS Bedrock model | ||
bedrock_response = client.chat.completions.create(ProviderNames.AWS_BEDROCK.value + ":" + "claude-v3", [ | ||
{"role": "user", "content": "Hello, world!"} | ||
]) | ||
self.assertEqual(bedrock_response, "AWS Bedrock Response") | ||
mock_bedrock.assert_called_once() | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |