From 406e64426aa0b3c940c5fe9453fc19f940050859 Mon Sep 17 00:00:00 2001
From: Rohit Prasad
Date: Wed, 4 Sep 2024 18:14:39 -0700
Subject: [PATCH] Refactoring the codebase. This change is a work in progress.

An alternate way of writing the unified LLM library:
- ProviderFactory returns a provider-specific implementation.
- Providers are imported lazily, based on the config passed to Client.
- Enums restrict the set of supported providers, without restricting the
  model names a provider accepts.

TODO:
- Support passing extra parameters to the model in create().
- Run a linter over the whole codebase for style conformance.
- Add the actual implementation and test cases for each provider.
---
 .gitignore                                   |  1 +
 aisuitealt/__init__.py                       |  0
 aisuitealt/client.py                         | 56 ++++++++++++++++++++
 aisuitealt/provider.py                       | 57 ++++++++++++++++++++
 aisuitealt/providers/__init__.py             |  1 +
 aisuitealt/providers/aws_bedrock_provider.py | 23 ++++++++
 aisuitealt/providers/openai_provider.py      | 17 ++++++
 aisuitealt/tests/__init__.py                 |  0
 aisuitealt/tests/test_client.py              | 47 +++++++++++++++
 9 files changed, 202 insertions(+)
 create mode 100644 aisuitealt/__init__.py
 create mode 100644 aisuitealt/client.py
 create mode 100644 aisuitealt/provider.py
 create mode 100644 aisuitealt/providers/__init__.py
 create mode 100644 aisuitealt/providers/aws_bedrock_provider.py
 create mode 100644 aisuitealt/providers/openai_provider.py
 create mode 100644 aisuitealt/tests/__init__.py
 create mode 100644 aisuitealt/tests/test_client.py

diff --git a/.gitignore b/.gitignore
index f1f1d58d..e1084c98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea/
 .vscode/
 __pycache__/
+env/
diff --git a/aisuitealt/__init__.py b/aisuitealt/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aisuitealt/client.py b/aisuitealt/client.py
new file mode 100644
index 00000000..4d645c41
--- /dev/null
+++ b/aisuitealt/client.py
@@ -0,0 +1,56 @@
+from provider import ProviderFactory, ProviderNames
+
+
+class Client:
+    def __init__(self, provider_configs: dict):
+        """
+        Initialize the client with provider configurations.
+        Use the ProviderFactory to create provider instances.
+        """
+        self.providers = {}
+        for provider_key, config in provider_configs.items():
+            # Check that the provider key is a valid ProviderNames enum member.
+            if not isinstance(provider_key, ProviderNames):
+                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
+            # Store the enum's string value so create() can look providers up by key.
+            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
+
+        self._chat = None
+
+    @property
+    def chat(self):
+        """Return the chat API interface."""
+        if not self._chat:
+            self._chat = Chat(self)
+        return self._chat
+
+
+class Chat:
+    def __init__(self, client: 'Client'):
+        self.client = client
+
+    @property
+    def completions(self):
+        """Return the completions interface."""
+        return Completions(self.client)
+
+
+class Completions:
+    def __init__(self, client: 'Client'):
+        self.client = client
+
+    def create(self, model: str, messages: list):
+        """Create a chat completion for the given model identifier."""
+        # Extract the provider key from the model identifier, e.g. "aws-bedrock:model-name".
+        provider_key, model_name = model.split(":", 1)
+
+        # Use the provider instance created by the factory.
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            # Include the registered providers in the error message.
+            raise ValueError(f"Provider {provider_key} is not present in the client. "
+                             f"Here are the providers: {self.client.providers}")
+
+        # Delegate the chat completion to the provider's implementation.
+        return provider.chat_completions_create(model_name, messages)
+
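Usage sketch, not part of the patch: assuming the aisuitealt directory is on
PYTHONPATH and a real API key is supplied, the Client above is driven with
"<provider>:<model>" identifiers, as the tests below also do. The key and the
model name here are placeholders.

    from client import Client
    from provider import ProviderNames

    client = Client({
        ProviderNames.OPENAI: {'api_key': 'sk-placeholder'},
    })

    # The model identifier is split on the first colon into provider and model.
    response = client.chat.completions.create(
        "openai:gpt-4o",
        [{"role": "user", "content": "Hello!"}],
    )
    print(response)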
diff --git a/aisuitealt/provider.py b/aisuitealt/provider.py
new file mode 100644
index 00000000..4edd2dda
--- /dev/null
+++ b/aisuitealt/provider.py
@@ -0,0 +1,57 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+import importlib
+
+class LLMError(Exception):
+    """Custom exception for LLM errors."""
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class Provider(ABC):
+    @abstractmethod
+    def chat_completions_create(self, model, messages):
+        """Abstract method for chat completion calls, to be implemented by each provider."""
+        pass
+
+
+class ProviderNames(Enum):
+    OPENAI = 'openai'
+    AWS_BEDROCK = 'aws-bedrock'
+
+
+class ProviderFactory:
+    """Factory to register and create provider instances based on keys."""
+
+    _provider_modules = {
+        ProviderNames.OPENAI: 'providers.openai_provider',
+        ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider',
+    }
+
+    @classmethod
+    def create_provider(cls, provider_key, config):
+        """Dynamically import and create an instance of a provider based on the provider key."""
+        if not isinstance(provider_key, ProviderNames):
+            raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
+
+        module_name = cls._provider_modules.get(provider_key)
+        if not module_name:
+            raise ValueError(f"Provider {provider_key.value} is not supported")
+
+        # Lazily load the module
+        try:
+            module = importlib.import_module(module_name)
+        except ImportError as e:
+            raise ImportError(f"Could not import module {module_name}: {str(e)}")
+
+        # Instantiate the provider class
+        provider_class = getattr(module, cls._get_provider_class_name(provider_key))
+        return provider_class(**config)
+
+    @staticmethod
+    def _get_provider_class_name(provider_key):
+        """Map provider key to the corresponding class name in the module."""
+        return {
+            ProviderNames.OPENAI: 'OpenAIProvider',
+            ProviderNames.AWS_BEDROCK: 'AWSBedrockProvider',
+        }[provider_key]
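Extension sketch, hypothetical and not part of the patch: registering a new
provider touches three places in provider.py (the ProviderNames enum,
_provider_modules, and _get_provider_class_name). For an assumed
providers.anthropic_provider module exposing an AnthropicProvider class:

    class ProviderNames(Enum):
        OPENAI = 'openai'
        AWS_BEDROCK = 'aws-bedrock'
        ANTHROPIC = 'anthropic'  # hypothetical new key

    # In ProviderFactory:
    _provider_modules = {
        ProviderNames.OPENAI: 'providers.openai_provider',
        ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider',
        ProviderNames.ANTHROPIC: 'providers.anthropic_provider',  # hypothetical
    }
    # ...plus a ProviderNames.ANTHROPIC: 'AnthropicProvider' entry in
    # _get_provider_class_name.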
diff --git a/aisuitealt/providers/__init__.py b/aisuitealt/providers/__init__.py
new file mode 100644
index 00000000..5910d76d
--- /dev/null
+++ b/aisuitealt/providers/__init__.py
@@ -0,0 +1 @@
+# Empty __init__.py
diff --git a/aisuitealt/providers/aws_bedrock_provider.py b/aisuitealt/providers/aws_bedrock_provider.py
new file mode 100644
index 00000000..5f46ea95
--- /dev/null
+++ b/aisuitealt/providers/aws_bedrock_provider.py
@@ -0,0 +1,23 @@
+from anthropic import AnthropicBedrock
+from provider import Provider, LLMError
+
+class AWSBedrockProvider(Provider):
+    def __init__(self, access_key, secret_key, session_token, region):
+        self.client = AnthropicBedrock(
+            aws_access_key=access_key,
+            aws_secret_key=secret_key,
+            aws_session_token=session_token,
+            aws_region=region
+        )
+
+    def chat_completions_create(self, model, messages):
+        try:
+            response = self.client.messages.create(
+                model=model,
+                max_tokens=256,
+                messages=messages
+            )
+            # AnthropicBedrock returns a Message object; the text is in its content blocks.
+            return response.content[0].text
+        except Exception as e:
+            raise LLMError(f"AWS Bedrock API error: {str(e)}")
diff --git a/aisuitealt/providers/openai_provider.py b/aisuitealt/providers/openai_provider.py
new file mode 100644
index 00000000..0926b81d
--- /dev/null
+++ b/aisuitealt/providers/openai_provider.py
@@ -0,0 +1,17 @@
+from openai import OpenAI
+from provider import Provider, LLMError
+
+class OpenAIProvider(Provider):
+    def __init__(self, api_key):
+        # openai>=1.0 removed the module-level ChatCompletion API; use a client instance.
+        self.client = OpenAI(api_key=api_key)
+
+    def chat_completions_create(self, model, messages):
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=messages
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            raise LLMError(f"OpenAI API error: {str(e)}")
diff --git a/aisuitealt/tests/__init__.py b/aisuitealt/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aisuitealt/tests/test_client.py b/aisuitealt/tests/test_client.py
new file mode 100644
index 00000000..5705f09a
--- /dev/null
+++ b/aisuitealt/tests/test_client.py
@@ -0,0 +1,47 @@
+import unittest
+from unittest.mock import patch
+from client import Client
+from provider import ProviderNames
+
+class TestClient(unittest.TestCase):
+
+    @patch('providers.openai_provider.OpenAIProvider.chat_completions_create')
+    @patch('providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create')
+    def test_client_chat_completions(self, mock_bedrock, mock_openai):
+        # Mock the responses from both providers.
+        mock_openai.return_value = "OpenAI Response"
+        mock_bedrock.return_value = "AWS Bedrock Response"
+
+        # Provider configurations.
+        provider_configs = {
+            ProviderNames.OPENAI: {
+                'api_key': 'test_openai_api_key'
+            },
+            ProviderNames.AWS_BEDROCK: {
+                'access_key': 'test_aws_access_key',
+                'secret_key': 'test_aws_secret_key',
+                'session_token': 'test_aws_session_token',
+                'region': 'us-west-2'
+            }
+        }
+
+        # Initialize the client.
+        client = Client(provider_configs)
+
+        # Test an OpenAI model.
+        openai_response = client.chat.completions.create(ProviderNames.OPENAI.value + ":" + "gpt-4o", [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Who won the world series in 2020?"}
+        ])
+        self.assertEqual(openai_response, "OpenAI Response")
+        mock_openai.assert_called_once()
+
+        # Test an AWS Bedrock model.
+        bedrock_response = client.chat.completions.create(ProviderNames.AWS_BEDROCK.value + ":" + "claude-v3", [
+            {"role": "user", "content": "Hello, world!"}
+        ])
+        self.assertEqual(bedrock_response, "AWS Bedrock Response")
+        mock_bedrock.assert_called_once()
+
+if __name__ == '__main__':
+    unittest.main()
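One possible shape for the extra-parameters TODO, sketched as an assumption
rather than as part of this patch: Completions.create() could accept **kwargs
and forward them, with each provider merging them into its SDK call.

    # In client.py, inside Completions:
    def create(self, model: str, messages: list, **kwargs):
        provider_key, model_name = model.split(":", 1)
        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Provider {provider_key} is not present in the client.")
        # Forward provider-specific options such as temperature or max_tokens.
        return provider.chat_completions_create(model_name, messages, **kwargs)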