From fca12d373eec55845cccf99c043a02ed53e5c194 Mon Sep 17 00:00:00 2001 From: rohit-rptless Date: Wed, 11 Sep 2024 11:35:53 -0700 Subject: [PATCH] Added older provider interface back for tests to pass. Added back few older files. Replaced test_client with newer test code. Will need to port the older provider files to the new format. Till then keeping the older provider interface related tests. --- aisuite/client.py | 3 +- aisuite/framework/__init__.py | 1 + aisuite/framework/provider_interface.py | 25 +++ aisuite/old_providers/__init__.py | 13 -- aisuite/providers/__init__.py | 14 +- .../anthropic_interface.py | 0 .../aws_bedrock_interface.py | 0 .../fireworks_interface.py | 0 .../google_interface.py | 0 .../groq_interface.py | 0 .../mistral_interface.py | 0 .../octo_interface.py | 0 .../ollama_interface.py | 0 .../openai_interface.py | 0 .../replicate_interface.py | 0 .../together_interface.py | 0 tests/client/test_client.py | 143 ++++++++++++++---- tests/client/test_client_basic.py | 78 ---------- 18 files changed, 154 insertions(+), 123 deletions(-) create mode 100644 aisuite/framework/provider_interface.py delete mode 100644 aisuite/old_providers/__init__.py rename aisuite/{old_providers => providers}/anthropic_interface.py (100%) rename aisuite/{old_providers => providers}/aws_bedrock_interface.py (100%) rename aisuite/{old_providers => providers}/fireworks_interface.py (100%) rename aisuite/{old_providers => providers}/google_interface.py (100%) rename aisuite/{old_providers => providers}/groq_interface.py (100%) rename aisuite/{old_providers => providers}/mistral_interface.py (100%) rename aisuite/{old_providers => providers}/octo_interface.py (100%) rename aisuite/{old_providers => providers}/ollama_interface.py (100%) rename aisuite/{old_providers => providers}/openai_interface.py (100%) rename aisuite/{old_providers => providers}/replicate_interface.py (100%) rename aisuite/{old_providers => providers}/together_interface.py (100%) delete mode 100644 tests/client/test_client_basic.py diff --git a/aisuite/client.py b/aisuite/client.py index 7dd088a9..e175cbe1 100644 --- a/aisuite/client.py +++ b/aisuite/client.py @@ -51,11 +51,12 @@ def chat(self): class Chat: def __init__(self, client: "Client"): self.client = client + self._completions = Completions(self.client) @property def completions(self): """Return the completions interface.""" - return Completions(self.client) + return self._completions class Completions: diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py index 7cb4a11f..aad7ebd2 100644 --- a/aisuite/framework/__init__.py +++ b/aisuite/framework/__init__.py @@ -1 +1,2 @@ +from .provider_interface import ProviderInterface from .chat_completion_response import ChatCompletionResponse diff --git a/aisuite/framework/provider_interface.py b/aisuite/framework/provider_interface.py new file mode 100644 index 00000000..3b6db766 --- /dev/null +++ b/aisuite/framework/provider_interface.py @@ -0,0 +1,25 @@ +"""The shared interface for model providers.""" + + +class ProviderInterface: + """Defines the expected behavior for provider-specific interfaces.""" + + def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: + """Create a chat completion using the specified messages, model, and temperature. + + This method must be implemented by subclasses to perform completions. + + Args: + ---- + messages (list): The chat history. + model (str): The identifier of the model to be used in the completion. + temperature (float): The temperature to use in the completion. + + Raises: + ------ + NotImplementedError: If this method has not been implemented by a subclass. + + """ + raise NotImplementedError( + "Provider Interface has not implemented chat_completion_create()" + ) diff --git a/aisuite/old_providers/__init__.py b/aisuite/old_providers/__init__.py deleted file mode 100644 index 816f790e..00000000 --- a/aisuite/old_providers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Provides the individual provider interfaces for each FM provider.""" - -from .anthropic_interface import AnthropicInterface -from .aws_bedrock_interface import AWSBedrockInterface -from .fireworks_interface import FireworksInterface -from .groq_interface import GroqInterface -from .mistral_interface import MistralInterface -from .octo_interface import OctoInterface -from .ollama_interface import OllamaInterface -from .openai_interface import OpenAIInterface -from .replicate_interface import ReplicateInterface -from .together_interface import TogetherInterface -from .google_interface import GoogleInterface diff --git a/aisuite/providers/__init__.py b/aisuite/providers/__init__.py index 5910d76d..816f790e 100644 --- a/aisuite/providers/__init__.py +++ b/aisuite/providers/__init__.py @@ -1 +1,13 @@ -# Empty __init__.py +"""Provides the individual provider interfaces for each FM provider.""" + +from .anthropic_interface import AnthropicInterface +from .aws_bedrock_interface import AWSBedrockInterface +from .fireworks_interface import FireworksInterface +from .groq_interface import GroqInterface +from .mistral_interface import MistralInterface +from .octo_interface import OctoInterface +from .ollama_interface import OllamaInterface +from .openai_interface import OpenAIInterface +from .replicate_interface import ReplicateInterface +from .together_interface import TogetherInterface +from .google_interface import GoogleInterface diff --git a/aisuite/old_providers/anthropic_interface.py b/aisuite/providers/anthropic_interface.py similarity index 100% rename from aisuite/old_providers/anthropic_interface.py rename to aisuite/providers/anthropic_interface.py diff --git a/aisuite/old_providers/aws_bedrock_interface.py b/aisuite/providers/aws_bedrock_interface.py similarity index 100% rename from aisuite/old_providers/aws_bedrock_interface.py rename to aisuite/providers/aws_bedrock_interface.py diff --git a/aisuite/old_providers/fireworks_interface.py b/aisuite/providers/fireworks_interface.py similarity index 100% rename from aisuite/old_providers/fireworks_interface.py rename to aisuite/providers/fireworks_interface.py diff --git a/aisuite/old_providers/google_interface.py b/aisuite/providers/google_interface.py similarity index 100% rename from aisuite/old_providers/google_interface.py rename to aisuite/providers/google_interface.py diff --git a/aisuite/old_providers/groq_interface.py b/aisuite/providers/groq_interface.py similarity index 100% rename from aisuite/old_providers/groq_interface.py rename to aisuite/providers/groq_interface.py diff --git a/aisuite/old_providers/mistral_interface.py b/aisuite/providers/mistral_interface.py similarity index 100% rename from aisuite/old_providers/mistral_interface.py rename to aisuite/providers/mistral_interface.py diff --git a/aisuite/old_providers/octo_interface.py b/aisuite/providers/octo_interface.py similarity index 100% rename from aisuite/old_providers/octo_interface.py rename to aisuite/providers/octo_interface.py diff --git a/aisuite/old_providers/ollama_interface.py b/aisuite/providers/ollama_interface.py similarity index 100% rename from aisuite/old_providers/ollama_interface.py rename to aisuite/providers/ollama_interface.py diff --git a/aisuite/old_providers/openai_interface.py b/aisuite/providers/openai_interface.py similarity index 100% rename from aisuite/old_providers/openai_interface.py rename to aisuite/providers/openai_interface.py diff --git a/aisuite/old_providers/replicate_interface.py b/aisuite/providers/replicate_interface.py similarity index 100% rename from aisuite/old_providers/replicate_interface.py rename to aisuite/providers/replicate_interface.py diff --git a/aisuite/old_providers/together_interface.py b/aisuite/providers/together_interface.py similarity index 100% rename from aisuite/old_providers/together_interface.py rename to aisuite/providers/together_interface.py diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 1232a37d..077974e6 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,45 +1,128 @@ -import pytest -from aisuite.client.client import Client, AnthropicInterface +import unittest +from unittest.mock import patch +from aisuite import Client +from aisuite import ProviderNames -def test_get_provider_interface_with_new_instance(): - """Test that get_provider_interface creates a new instance of the interface.""" - client = Client() - interface, model_name = client.get_provider_interface("anthropic:some-model:v1") - assert isinstance(interface, AnthropicInterface) - assert model_name == "some-model:v1" - assert client.all_interfaces["anthropic"] == interface +class TestClient(unittest.TestCase): + @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") + @patch( + "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create" + ) + @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create") + @patch( + "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create" + ) + def test_client_chat_completions( + self, mock_anthropic, mock_azure, mock_bedrock, mock_openai + ): + # Mock responses from providers + mock_openai.return_value = "OpenAI Response" + mock_bedrock.return_value = "AWS Bedrock Response" + mock_azure.return_value = "Azure Response" + mock_anthropic.return_value = "Anthropic Response" -def test_get_provider_interface_with_existing_instance(): - """Test that get_provider_interface returns an existing instance of the interface, if already created.""" - client = Client() + # Provider configurations + provider_configs = { + ProviderNames.OPENAI: {"api_key": "test_openai_api_key"}, + ProviderNames.AWS_BEDROCK: { + "aws_access_key": "test_aws_access_key", + "aws_secret_key": "test_aws_secret_key", + "aws_session_token": "test_aws_session_token", + "aws_region": "us-west-2", + }, + ProviderNames.AZURE: { + "api_key": "azure-api-key", + }, + } - # New interface instance - new_instance, _ = client.get_provider_interface("anthropic:some-model:v2") + # Initialize the client + client = Client() + client.configure(provider_configs) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + ] - # Call twice, get same instance back - same_instance, _ = client.get_provider_interface("anthropic:some-model:v2") + # Test OpenAI model + open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o" + openai_response = client.chat.completions.create( + open_ai_model, messages=messages + ) + self.assertEqual(openai_response, "OpenAI Response") + mock_openai.assert_called_once() - assert new_instance is same_instance + # Test AWS Bedrock model + bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3" + bedrock_response = client.chat.completions.create( + bedrock_model, messages=messages + ) + self.assertEqual(bedrock_response, "AWS Bedrock Response") + mock_bedrock.assert_called_once() + azure_model = ProviderNames.AZURE + ":" + "azure-model" + azure_response = client.chat.completions.create(azure_model, messages=messages) + self.assertEqual(azure_response, "Azure Response") + mock_azure.assert_called_once() -def test_get_provider_interface_with_invalid_format(): - client = Client() + anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model" + anthropic_response = client.chat.completions.create( + anthropic_model, messages=messages + ) + self.assertEqual(anthropic_response, "Anthropic Response") + mock_anthropic.assert_called_once() - with pytest.raises(ValueError) as exc_info: - client.get_provider_interface("invalid-model-no-colon") + # Test that new instances of Completion are not created each time we make an inference call. + compl_instance = client.chat.completions + next_compl_instance = client.chat.completions + assert compl_instance is next_compl_instance - assert "Expected ':' in model identifier" in str(exc_info.value) + @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") + def test_invalid_provider_in_client_config(self, mock_openai): + # Testing an invalid provider name in the configuration + invalid_provider_configs = { + "INVALID_PROVIDER": {"api_key": "invalid_api_key"}, + } + # Expect ValueError when initializing Client with invalid provider + with self.assertRaises(ValueError) as context: + client = Client(invalid_provider_configs) -def test_get_provider_interface_with_unknown_interface(): - client = Client() + # Verify the error message + self.assertIn( + "Provider INVALID_PROVIDER is not a valid ProviderNames enum", + str(context.exception), + ) - with pytest.raises(Exception) as exc_info: - client.get_provider_interface("unknown-interface:some-model") + @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") + def test_invalid_model_format_in_create(self, mock_openai): + # Valid provider configurations + provider_configs = { + ProviderNames.OPENAI: {"api_key": "test_openai_api_key"}, + } - assert ( - "Could not find factory to create interface for provider 'unknown-interface'" - in str(exc_info.value) - ) + # Initialize the client with valid provider + client = Client(provider_configs) + client.configure(provider_configs) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + + # Invalid model format + invalid_model = "invalidmodel" + + # Expect ValueError when calling create with invalid model format + with self.assertRaises(ValueError) as context: + client.chat.completions.create(invalid_model, messages=messages) + + # Verify the error message + self.assertIn( + "Invalid model format. Expected 'provider:model'", str(context.exception) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/client/test_client_basic.py b/tests/client/test_client_basic.py deleted file mode 100644 index f8d8e77f..00000000 --- a/tests/client/test_client_basic.py +++ /dev/null @@ -1,78 +0,0 @@ -import unittest -from unittest.mock import patch -from aisuite import Client -from aisuite import ProviderNames - - -class TestClient(unittest.TestCase): - - @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") - @patch( - "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create" - ) - @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create") - @patch( - "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create" - ) - def test_client_chat_completions( - self, mock_anthropic, mock_azure, mock_bedrock, mock_openai - ): - # Mock responses from providers - mock_openai.return_value = "OpenAI Response" - mock_bedrock.return_value = "AWS Bedrock Response" - mock_azure.return_value = "Azure Response" - mock_anthropic.return_value = "Anthropic Response" - - # Provider configurations - provider_configs = { - ProviderNames.OPENAI: {"api_key": "test_openai_api_key"}, - ProviderNames.AWS_BEDROCK: { - "aws_access_key": "test_aws_access_key", - "aws_secret_key": "test_aws_secret_key", - "aws_session_token": "test_aws_session_token", - "aws_region": "us-west-2", - }, - ProviderNames.AZURE: { - "api_key": "azure-api-key", - }, - } - - # Initialize the client - client = Client() - client.configure(provider_configs) - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Who won the world series in 2020?"}, - ] - - # Test OpenAI model - open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o" - openai_response = client.chat.completions.create( - open_ai_model, messages=messages - ) - self.assertEqual(openai_response, "OpenAI Response") - mock_openai.assert_called_once() - - # Test AWS Bedrock model - bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3" - bedrock_response = client.chat.completions.create( - bedrock_model, messages=messages - ) - self.assertEqual(bedrock_response, "AWS Bedrock Response") - mock_bedrock.assert_called_once() - - azure_model = ProviderNames.AZURE + ":" + "azure-model" - azure_response = client.chat.completions.create(azure_model, messages=messages) - self.assertEqual(azure_response, "Azure Response") - mock_azure.assert_called_once() - - anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model" - anthropic_response = client.chat.completions.create( - anthropic_model, messages=messages - ) - self.assertEqual(anthropic_response, "Anthropic Response") - mock_anthropic.assert_called_once() - - -if __name__ == "__main__": - unittest.main()