Skip to content

Commit fca12d3

Browse files
committed
Added older provider interface back for tests to pass.
Added back few older files. Replaced test_client with newer test code. Will need to port the older provider files to the new format. Till then keeping the older provider interface related tests.
1 parent d9ede71 commit fca12d3

18 files changed

+154
-123
lines changed

aisuite/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ def chat(self):
5151
class Chat:
5252
def __init__(self, client: "Client"):
5353
self.client = client
54+
self._completions = Completions(self.client)
5455

5556
@property
5657
def completions(self):
5758
"""Return the completions interface."""
58-
return Completions(self.client)
59+
return self._completions
5960

6061

6162
class Completions:

aisuite/framework/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .provider_interface import ProviderInterface
12
from .chat_completion_response import ChatCompletionResponse
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""The shared interface for model providers."""
2+
3+
4+
class ProviderInterface:
5+
"""Defines the expected behavior for provider-specific interfaces."""
6+
7+
def chat_completion_create(self, messages=None, model=None, temperature=0) -> None:
8+
"""Create a chat completion using the specified messages, model, and temperature.
9+
10+
This method must be implemented by subclasses to perform completions.
11+
12+
Args:
13+
----
14+
messages (list): The chat history.
15+
model (str): The identifier of the model to be used in the completion.
16+
temperature (float): The temperature to use in the completion.
17+
18+
Raises:
19+
------
20+
NotImplementedError: If this method has not been implemented by a subclass.
21+
22+
"""
23+
raise NotImplementedError(
24+
"Provider Interface has not implemented chat_completion_create()"
25+
)

aisuite/old_providers/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

aisuite/providers/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,13 @@
1-
# Empty __init__.py
1+
"""Provides the individual provider interfaces for each FM provider."""
2+
3+
from .anthropic_interface import AnthropicInterface
4+
from .aws_bedrock_interface import AWSBedrockInterface
5+
from .fireworks_interface import FireworksInterface
6+
from .groq_interface import GroqInterface
7+
from .mistral_interface import MistralInterface
8+
from .octo_interface import OctoInterface
9+
from .ollama_interface import OllamaInterface
10+
from .openai_interface import OpenAIInterface
11+
from .replicate_interface import ReplicateInterface
12+
from .together_interface import TogetherInterface
13+
from .google_interface import GoogleInterface

tests/client/test_client.py

Lines changed: 113 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,128 @@
1-
import pytest
2-
from aisuite.client.client import Client, AnthropicInterface
1+
import unittest
2+
from unittest.mock import patch
3+
from aisuite import Client
4+
from aisuite import ProviderNames
35

46

5-
def test_get_provider_interface_with_new_instance():
6-
"""Test that get_provider_interface creates a new instance of the interface."""
7-
client = Client()
8-
interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
9-
assert isinstance(interface, AnthropicInterface)
10-
assert model_name == "some-model:v1"
11-
assert client.all_interfaces["anthropic"] == interface
7+
class TestClient(unittest.TestCase):
128

9+
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
10+
@patch(
11+
"aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
12+
)
13+
@patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
14+
@patch(
15+
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
16+
)
17+
def test_client_chat_completions(
18+
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
19+
):
20+
# Mock responses from providers
21+
mock_openai.return_value = "OpenAI Response"
22+
mock_bedrock.return_value = "AWS Bedrock Response"
23+
mock_azure.return_value = "Azure Response"
24+
mock_anthropic.return_value = "Anthropic Response"
1325

14-
def test_get_provider_interface_with_existing_instance():
15-
"""Test that get_provider_interface returns an existing instance of the interface, if already created."""
16-
client = Client()
26+
# Provider configurations
27+
provider_configs = {
28+
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
29+
ProviderNames.AWS_BEDROCK: {
30+
"aws_access_key": "test_aws_access_key",
31+
"aws_secret_key": "test_aws_secret_key",
32+
"aws_session_token": "test_aws_session_token",
33+
"aws_region": "us-west-2",
34+
},
35+
ProviderNames.AZURE: {
36+
"api_key": "azure-api-key",
37+
},
38+
}
1739

18-
# New interface instance
19-
new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
40+
# Initialize the client
41+
client = Client()
42+
client.configure(provider_configs)
43+
messages = [
44+
{"role": "system", "content": "You are a helpful assistant."},
45+
{"role": "user", "content": "Who won the world series in 2020?"},
46+
]
2047

21-
# Call twice, get same instance back
22-
same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
48+
# Test OpenAI model
49+
open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
50+
openai_response = client.chat.completions.create(
51+
open_ai_model, messages=messages
52+
)
53+
self.assertEqual(openai_response, "OpenAI Response")
54+
mock_openai.assert_called_once()
2355

24-
assert new_instance is same_instance
56+
# Test AWS Bedrock model
57+
bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
58+
bedrock_response = client.chat.completions.create(
59+
bedrock_model, messages=messages
60+
)
61+
self.assertEqual(bedrock_response, "AWS Bedrock Response")
62+
mock_bedrock.assert_called_once()
2563

64+
azure_model = ProviderNames.AZURE + ":" + "azure-model"
65+
azure_response = client.chat.completions.create(azure_model, messages=messages)
66+
self.assertEqual(azure_response, "Azure Response")
67+
mock_azure.assert_called_once()
2668

27-
def test_get_provider_interface_with_invalid_format():
28-
client = Client()
69+
anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
70+
anthropic_response = client.chat.completions.create(
71+
anthropic_model, messages=messages
72+
)
73+
self.assertEqual(anthropic_response, "Anthropic Response")
74+
mock_anthropic.assert_called_once()
2975

30-
with pytest.raises(ValueError) as exc_info:
31-
client.get_provider_interface("invalid-model-no-colon")
76+
# Test that new instances of Completion are not created each time we make an inference call.
77+
compl_instance = client.chat.completions
78+
next_compl_instance = client.chat.completions
79+
assert compl_instance is next_compl_instance
3280

33-
assert "Expected ':' in model identifier" in str(exc_info.value)
81+
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
82+
def test_invalid_provider_in_client_config(self, mock_openai):
83+
# Testing an invalid provider name in the configuration
84+
invalid_provider_configs = {
85+
"INVALID_PROVIDER": {"api_key": "invalid_api_key"},
86+
}
3487

88+
# Expect ValueError when initializing Client with invalid provider
89+
with self.assertRaises(ValueError) as context:
90+
client = Client(invalid_provider_configs)
3591

36-
def test_get_provider_interface_with_unknown_interface():
37-
client = Client()
92+
# Verify the error message
93+
self.assertIn(
94+
"Provider INVALID_PROVIDER is not a valid ProviderNames enum",
95+
str(context.exception),
96+
)
3897

39-
with pytest.raises(Exception) as exc_info:
40-
client.get_provider_interface("unknown-interface:some-model")
98+
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
99+
def test_invalid_model_format_in_create(self, mock_openai):
100+
# Valid provider configurations
101+
provider_configs = {
102+
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
103+
}
41104

42-
assert (
43-
"Could not find factory to create interface for provider 'unknown-interface'"
44-
in str(exc_info.value)
45-
)
105+
# Initialize the client with valid provider
106+
client = Client(provider_configs)
107+
client.configure(provider_configs)
108+
109+
messages = [
110+
{"role": "system", "content": "You are a helpful assistant."},
111+
{"role": "user", "content": "Tell me a joke."},
112+
]
113+
114+
# Invalid model format
115+
invalid_model = "invalidmodel"
116+
117+
# Expect ValueError when calling create with invalid model format
118+
with self.assertRaises(ValueError) as context:
119+
client.chat.completions.create(invalid_model, messages=messages)
120+
121+
# Verify the error message
122+
self.assertIn(
123+
"Invalid model format. Expected 'provider:model'", str(context.exception)
124+
)
125+
126+
127+
if __name__ == "__main__":
128+
unittest.main()

tests/client/test_client_basic.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

0 commit comments

Comments
 (0)