|
1 |
| -import pytest |
2 |
| -from aisuite.client.client import Client, AnthropicInterface |
| 1 | +import unittest |
| 2 | +from unittest.mock import patch |
| 3 | +from aisuite import Client |
| 4 | +from aisuite import ProviderNames |
3 | 5 |
|
4 | 6 |
|
5 |
| -def test_get_provider_interface_with_new_instance(): |
6 |
| - """Test that get_provider_interface creates a new instance of the interface.""" |
7 |
| - client = Client() |
8 |
| - interface, model_name = client.get_provider_interface("anthropic:some-model:v1") |
9 |
| - assert isinstance(interface, AnthropicInterface) |
10 |
| - assert model_name == "some-model:v1" |
11 |
| - assert client.all_interfaces["anthropic"] == interface |
class TestClient(unittest.TestCase):
    """Unit tests for the aisuite ``Client`` facade with every provider backend mocked."""

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    @patch(
        "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
    )
    @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
    @patch(
        "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
    )
    def test_client_chat_completions(
        self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
    ):
        """Each provider-prefixed model id must be routed to its own backend."""
        # Decorators apply bottom-up: the innermost patch (anthropic) binds to
        # the first mock parameter and the outermost (openai) to the last.
        mock_openai.return_value = "OpenAI Response"
        mock_bedrock.return_value = "AWS Bedrock Response"
        mock_azure.return_value = "Azure Response"
        mock_anthropic.return_value = "Anthropic Response"

        # Per-provider credentials (dummy values; backends are mocked anyway).
        configs = {
            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
            ProviderNames.AWS_BEDROCK: {
                "aws_access_key": "test_aws_access_key",
                "aws_secret_key": "test_aws_secret_key",
                "aws_session_token": "test_aws_session_token",
                "aws_region": "us-west-2",
            },
            ProviderNames.AZURE: {"api_key": "azure-api-key"},
        }

        client = Client()
        client.configure(configs)
        chat_history = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]

        # (model id, patched backend mock, canned reply) for every provider.
        cases = [
            (ProviderNames.OPENAI + ":" + "gpt-4o", mock_openai, "OpenAI Response"),
            (
                ProviderNames.AWS_BEDROCK + ":" + "claude-v3",
                mock_bedrock,
                "AWS Bedrock Response",
            ),
            (ProviderNames.AZURE + ":" + "azure-model", mock_azure, "Azure Response"),
            (
                ProviderNames.ANTHROPIC + ":" + "anthropic-model",
                mock_anthropic,
                "Anthropic Response",
            ),
        ]
        for model_id, backend_mock, expected_reply in cases:
            reply = client.chat.completions.create(model_id, messages=chat_history)
            self.assertEqual(reply, expected_reply)
            backend_mock.assert_called_once()

        # The Completions accessor must return a cached instance, not build a
        # new one on every inference call.
        self.assertIs(client.chat.completions, client.chat.completions)

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    def test_invalid_provider_in_client_config(self, mock_openai):
        """Client construction rejects config keys that are not ProviderNames members."""
        bad_configs = {"INVALID_PROVIDER": {"api_key": "invalid_api_key"}}

        with self.assertRaises(ValueError) as caught:
            Client(bad_configs)

        self.assertIn(
            "Provider INVALID_PROVIDER is not a valid ProviderNames enum",
            str(caught.exception),
        )

    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
    def test_invalid_model_format_in_create(self, mock_openai):
        """create() must reject model ids lacking the 'provider:model' shape."""
        configs = {ProviderNames.OPENAI: {"api_key": "test_openai_api_key"}}

        client = Client(configs)
        client.configure(configs)

        chat_history = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Tell me a joke."},
        ]

        # No ':' separator — the client should refuse before hitting a provider.
        with self.assertRaises(ValueError) as caught:
            client.chat.completions.create("invalidmodel", messages=chat_history)

        self.assertIn(
            "Invalid model format. Expected 'provider:model'", str(caught.exception)
        )
| 125 | + |
| 126 | + |
# Allow running this test module directly (e.g. `python test_client.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments