Commit
Added older provider interface back for tests to pass.
Added back a few older files. Replaced test_client with newer test code. The older provider files will need to be ported to the new format; until then, the older provider-interface tests are kept.
1 parent d9ede71, commit fca12d3
Showing 18 changed files with 154 additions and 123 deletions.
@@ -1 +1,2 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
@@ -0,0 +1,25 @@
"""The shared interface for model providers."""


class ProviderInterface:
    """Defines the expected behavior for provider-specific interfaces."""

    def chat_completion_create(self, messages=None, model=None, temperature=0) -> None:
        """Create a chat completion using the specified messages, model, and temperature.

        This method must be implemented by subclasses to perform completions.

        Args:
        ----
            messages (list): The chat history.
            model (str): The identifier of the model to be used in the completion.
            temperature (float): The temperature to use in the completion.

        Raises:
        ------
            NotImplementedError: If this method has not been implemented by a subclass.

        """
        raise NotImplementedError(
            "Provider Interface has not implemented chat_completion_create()"
        )
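For context, a concrete provider is expected to subclass ProviderInterface and override chat_completion_create. Below is a minimal sketch of such a subclass, assuming the OpenAI Python SDK as the backend; the implementation is illustrative only and is not part of this commit.

# Hypothetical subclass sketch; the OpenAI SDK usage below is an assumption,
# not code from this commit.
import openai


class OpenAIInterface(ProviderInterface):
    """Implements the shared interface on top of the OpenAI chat API."""

    def __init__(self, api_key=None):
        self.client = openai.OpenAI(api_key=api_key)

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        # Forward the shared arguments to the provider-specific SDK call.
        return self.client.chat.completions.create(
            model=model, messages=messages, temperature=temperature
        )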
This file was deleted.
@@ -1 +1,13 @@
Removed:

# Empty __init__.py

Added:

"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .aws_bedrock_interface import AWSBedrockInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .octo_interface import OctoInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
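The older client (whose tests are kept further down in this commit) resolved a "provider:model" string to one of the interfaces above via get_provider_interface. That client code is not part of this diff, so the lookup sketch below is only an assumption that mirrors the error messages asserted in the old tests.

# Hypothetical sketch of resolving "provider:model" to an interface instance;
# the registry dict and its contents are assumptions, not code from this commit.
_INTERFACE_FACTORIES = {
    "anthropic": AnthropicInterface,
    "openai": OpenAIInterface,
    # ... one entry per interface imported above
}


def get_provider_interface(model_identifier):
    """Split 'provider:model' and return (interface_instance, model_name)."""
    if ":" not in model_identifier:
        raise ValueError(f"Expected ':' in model identifier: {model_identifier}")
    provider_key, model_name = model_identifier.split(":", 1)
    if provider_key not in _INTERFACE_FACTORIES:
        raise Exception(
            f"Could not find factory to create interface for provider '{provider_key}'"
        )
    return _INTERFACE_FACTORIES[provider_key](), model_name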
11 files renamed without changes.
@@ -1,45 +1,128 @@

Removed (old pytest-style tests):

import pytest
from aisuite.client.client import Client, AnthropicInterface


def test_get_provider_interface_with_new_instance():
    """Test that get_provider_interface creates a new instance of the interface."""
    client = Client()
    interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
    assert isinstance(interface, AnthropicInterface)
    assert model_name == "some-model:v1"
    assert client.all_interfaces["anthropic"] == interface


def test_get_provider_interface_with_existing_instance():
    """Test that get_provider_interface returns an existing instance of the interface, if already created."""
    client = Client()

    # New interface instance
    new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")

    # Call twice, get same instance back
    same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")

    assert new_instance is same_instance


def test_get_provider_interface_with_invalid_format():
    client = Client()

    with pytest.raises(ValueError) as exc_info:
        client.get_provider_interface("invalid-model-no-colon")

    assert "Expected ':' in model identifier" in str(exc_info.value)


def test_get_provider_interface_with_unknown_interface():
    client = Client()

    with pytest.raises(Exception) as exc_info:
        client.get_provider_interface("unknown-interface:some-model")

    assert (
        "Could not find factory to create interface for provider 'unknown-interface'"
        in str(exc_info.value)
    )

Added (new unittest-style TestClient suite):

import unittest
from unittest.mock import patch
from aisuite import Client
from aisuite import ProviderNames


class TestClient(unittest.TestCase):

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    @patch(
        "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
    )
    @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
    @patch(
        "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
    )
    def test_client_chat_completions(
        self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
    ):
        # Mock responses from providers
        mock_openai.return_value = "OpenAI Response"
        mock_bedrock.return_value = "AWS Bedrock Response"
        mock_azure.return_value = "Azure Response"
        mock_anthropic.return_value = "Anthropic Response"

        # Provider configurations
        provider_configs = {
            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
            ProviderNames.AWS_BEDROCK: {
                "aws_access_key": "test_aws_access_key",
                "aws_secret_key": "test_aws_secret_key",
                "aws_session_token": "test_aws_session_token",
                "aws_region": "us-west-2",
            },
            ProviderNames.AZURE: {
                "api_key": "azure-api-key",
            },
        }

        # Initialize the client
        client = Client()
        client.configure(provider_configs)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]

        # Test OpenAI model
        open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
        openai_response = client.chat.completions.create(
            open_ai_model, messages=messages
        )
        self.assertEqual(openai_response, "OpenAI Response")
        mock_openai.assert_called_once()

        # Test AWS Bedrock model
        bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
        bedrock_response = client.chat.completions.create(
            bedrock_model, messages=messages
        )
        self.assertEqual(bedrock_response, "AWS Bedrock Response")
        mock_bedrock.assert_called_once()

        azure_model = ProviderNames.AZURE + ":" + "azure-model"
        azure_response = client.chat.completions.create(azure_model, messages=messages)
        self.assertEqual(azure_response, "Azure Response")
        mock_azure.assert_called_once()

        anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
        anthropic_response = client.chat.completions.create(
            anthropic_model, messages=messages
        )
        self.assertEqual(anthropic_response, "Anthropic Response")
        mock_anthropic.assert_called_once()

        # Test that new instances of Completion are not created each time we make an inference call.
        compl_instance = client.chat.completions
        next_compl_instance = client.chat.completions
        assert compl_instance is next_compl_instance

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    def test_invalid_provider_in_client_config(self, mock_openai):
        # Testing an invalid provider name in the configuration
        invalid_provider_configs = {
            "INVALID_PROVIDER": {"api_key": "invalid_api_key"},
        }

        # Expect ValueError when initializing Client with invalid provider
        with self.assertRaises(ValueError) as context:
            client = Client(invalid_provider_configs)

        # Verify the error message
        self.assertIn(
            "Provider INVALID_PROVIDER is not a valid ProviderNames enum",
            str(context.exception),
        )

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    def test_invalid_model_format_in_create(self, mock_openai):
        # Valid provider configurations
        provider_configs = {
            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
        }

        # Initialize the client with valid provider
        client = Client(provider_configs)
        client.configure(provider_configs)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Tell me a joke."},
        ]

        # Invalid model format
        invalid_model = "invalidmodel"

        # Expect ValueError when calling create with invalid model format
        with self.assertRaises(ValueError) as context:
            client.chat.completions.create(invalid_model, messages=messages)

        # Verify the error message
        self.assertIn(
            "Invalid model format. Expected 'provider:model'", str(context.exception)
        )


if __name__ == "__main__":
    unittest.main()
This file was deleted.