Skip to content

Commit

Permalink
Added older provider interface back for tests to pass.
Browse files Browse the repository at this point in the history
Added back few older files.
Replaced test_client with newer test code.
Will need to port the older provider files
to the new format. Till then keeping the
older provider interface related tests.
  • Loading branch information
rohit-rptless committed Sep 11, 2024
1 parent d9ede71 commit fca12d3
Show file tree
Hide file tree
Showing 18 changed files with 154 additions and 123 deletions.
3 changes: 2 additions & 1 deletion aisuite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def chat(self):
class Chat:
def __init__(self, client: "Client"):
self.client = client
self._completions = Completions(self.client)

@property
def completions(self):
"""Return the completions interface."""
return Completions(self.client)
return self._completions


class Completions:
Expand Down
1 change: 1 addition & 0 deletions aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
25 changes: 25 additions & 0 deletions aisuite/framework/provider_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""The shared interface for model providers."""


class ProviderInterface:
"""Defines the expected behavior for provider-specific interfaces."""

def chat_completion_create(self, messages=None, model=None, temperature=0) -> None:
"""Create a chat completion using the specified messages, model, and temperature.
This method must be implemented by subclasses to perform completions.
Args:
----
messages (list): The chat history.
model (str): The identifier of the model to be used in the completion.
temperature (float): The temperature to use in the completion.
Raises:
------
NotImplementedError: If this method has not been implemented by a subclass.
"""
raise NotImplementedError(
"Provider Interface has not implemented chat_completion_create()"
)
13 changes: 0 additions & 13 deletions aisuite/old_providers/__init__.py

This file was deleted.

14 changes: 13 additions & 1 deletion aisuite/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
# Empty __init__.py
"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .aws_bedrock_interface import AWSBedrockInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .octo_interface import OctoInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
143 changes: 113 additions & 30 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,128 @@
import pytest
from aisuite.client.client import Client, AnthropicInterface
import unittest
from unittest.mock import patch
from aisuite import Client
from aisuite import ProviderNames


def test_get_provider_interface_with_new_instance():
"""Test that get_provider_interface creates a new instance of the interface."""
client = Client()
interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
assert isinstance(interface, AnthropicInterface)
assert model_name == "some-model:v1"
assert client.all_interfaces["anthropic"] == interface
class TestClient(unittest.TestCase):

@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
@patch(
"aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
)
@patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
@patch(
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
)
def test_client_chat_completions(
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
):
# Mock responses from providers
mock_openai.return_value = "OpenAI Response"
mock_bedrock.return_value = "AWS Bedrock Response"
mock_azure.return_value = "Azure Response"
mock_anthropic.return_value = "Anthropic Response"

def test_get_provider_interface_with_existing_instance():
"""Test that get_provider_interface returns an existing instance of the interface, if already created."""
client = Client()
# Provider configurations
provider_configs = {
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
ProviderNames.AWS_BEDROCK: {
"aws_access_key": "test_aws_access_key",
"aws_secret_key": "test_aws_secret_key",
"aws_session_token": "test_aws_session_token",
"aws_region": "us-west-2",
},
ProviderNames.AZURE: {
"api_key": "azure-api-key",
},
}

# New interface instance
new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
# Initialize the client
client = Client()
client.configure(provider_configs)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
]

# Call twice, get same instance back
same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
# Test OpenAI model
open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
openai_response = client.chat.completions.create(
open_ai_model, messages=messages
)
self.assertEqual(openai_response, "OpenAI Response")
mock_openai.assert_called_once()

assert new_instance is same_instance
# Test AWS Bedrock model
bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
bedrock_response = client.chat.completions.create(
bedrock_model, messages=messages
)
self.assertEqual(bedrock_response, "AWS Bedrock Response")
mock_bedrock.assert_called_once()

azure_model = ProviderNames.AZURE + ":" + "azure-model"
azure_response = client.chat.completions.create(azure_model, messages=messages)
self.assertEqual(azure_response, "Azure Response")
mock_azure.assert_called_once()

def test_get_provider_interface_with_invalid_format():
client = Client()
anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
anthropic_response = client.chat.completions.create(
anthropic_model, messages=messages
)
self.assertEqual(anthropic_response, "Anthropic Response")
mock_anthropic.assert_called_once()

with pytest.raises(ValueError) as exc_info:
client.get_provider_interface("invalid-model-no-colon")
# Test that new instances of Completion are not created each time we make an inference call.
compl_instance = client.chat.completions
next_compl_instance = client.chat.completions
assert compl_instance is next_compl_instance

assert "Expected ':' in model identifier" in str(exc_info.value)
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
def test_invalid_provider_in_client_config(self, mock_openai):
# Testing an invalid provider name in the configuration
invalid_provider_configs = {
"INVALID_PROVIDER": {"api_key": "invalid_api_key"},
}

# Expect ValueError when initializing Client with invalid provider
with self.assertRaises(ValueError) as context:
client = Client(invalid_provider_configs)

def test_get_provider_interface_with_unknown_interface():
client = Client()
# Verify the error message
self.assertIn(
"Provider INVALID_PROVIDER is not a valid ProviderNames enum",
str(context.exception),
)

with pytest.raises(Exception) as exc_info:
client.get_provider_interface("unknown-interface:some-model")
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
def test_invalid_model_format_in_create(self, mock_openai):
# Valid provider configurations
provider_configs = {
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
}

assert (
"Could not find factory to create interface for provider 'unknown-interface'"
in str(exc_info.value)
)
# Initialize the client with valid provider
client = Client(provider_configs)
client.configure(provider_configs)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me a joke."},
]

# Invalid model format
invalid_model = "invalidmodel"

# Expect ValueError when calling create with invalid model format
with self.assertRaises(ValueError) as context:
client.chat.completions.create(invalid_model, messages=messages)

# Verify the error message
self.assertIn(
"Invalid model format. Expected 'provider:model'", str(context.exception)
)


if __name__ == "__main__":
unittest.main()
78 changes: 0 additions & 78 deletions tests/client/test_client_basic.py

This file was deleted.

0 comments on commit fca12d3

Please sign in to comment.