Skip to content

Commit 9a89a02

Browse files
committed
Test cases for Providers.
Test case checks for translation into correct set of messages and additional parameters being passed to the respective SDK calls.
1 parent 9c2680e commit 9a89a02

File tree

4 files changed

+105
-2
lines changed

4 files changed

+105
-2
lines changed

aisuitealt/provider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def chat_completions_create(self, model, messages):
1818
class ProviderNames(str, Enum):
1919
OPENAI = 'openai'
2020
AWS_BEDROCK = 'aws-bedrock'
21+
ANTHROPIC = 'anthropic'
2122

2223

2324
class ProviderFactory:
@@ -26,6 +27,7 @@ class ProviderFactory:
2627
_provider_modules = {
2728
ProviderNames.OPENAI: 'providers.openai_provider',
2829
ProviderNames.AWS_BEDROCK: 'providers.aws_bedrock_provider',
30+
ProviderNames.ANTHROPIC: 'providers.anthropic_provider',
2931
}
3032

3133
@classmethod
@@ -54,4 +56,5 @@ def _get_provider_class_name(provider_key):
5456
return {
5557
ProviderNames.OPENAI: 'OpenAIProvider',
5658
ProviderNames.AWS_BEDROCK: 'AWSBedrockProvider',
59+
ProviderNames.ANTHROPIC: 'AnthropicProvider',
5760
}[provider_key]

aisuitealt/providers/aws_bedrock_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def __init__(self, **config):
99
"""
1010
# Anthropic Bedrock client will use the default AWS credential providers, such as
1111
# using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
12-
# Any overrides from the user is passed to the constructor.
13-
self.client = boto3.client("bedrock-runtime", **config)
12+
# It does not like parameters passed to the constructor.
13+
self.client = boto3.client("bedrock-runtime")
1414
# Maintain a list of Inference Parameters which Bedrock supports.
1515
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
1616
self.inference_parameters = ['maxTokens', 'temperature', 'topP', 'stopSequences']
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from client import Client
4+
from provider import ProviderNames
5+
6+
class TestAnthropicProvider(unittest.TestCase):
7+
8+
@patch('anthropic.Anthropic')
9+
def test_anthropic_chat_completions(self, mock_anthropic):
10+
"""Test that correct parameters are passed to Anthropic's chat completions API."""
11+
# Mocking the Anthropic client
12+
mock_anthropic_client = MagicMock()
13+
mock_anthropic.return_value = mock_anthropic_client
14+
15+
# Mock response from the Anthropic API
16+
mock_anthropic_client.messages.create.return_value = {
17+
"role": "assistant",
18+
"content": "Test response from Anthropic"
19+
}
20+
21+
client = Client()
22+
client.configure({
23+
ProviderNames.ANTHROPIC: {
24+
'api_key': 'test_anthropic_api_key'
25+
}
26+
})
27+
28+
messages = [
29+
{"role": "system", "content": "You are a helpful assistant."},
30+
{"role": "user", "content": "Describe cloud computing."}
31+
]
32+
33+
response = client.chat.completions.create(ProviderNames.ANTHROPIC + ":" + "claude-v3", messages)
34+
35+
# Assert that the converse API was called with the correct parameters
36+
mock_anthropic_client.messages.create.assert_called_once_with(
37+
model="claude-v3",
38+
system="You are a helpful assistant.",
39+
messages=[
40+
{"role": "user", "content": "Describe cloud computing."}
41+
]
42+
)
43+
44+
# Check that the response is normalized correctly
45+
self.assertEqual(response['choices'][0]['message']['content'], "Test response from Anthropic")
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from client import Client
4+
from provider import ProviderNames
5+
6+
class TestAWSBedrockProvider(unittest.TestCase):
7+
8+
@patch('boto3.client')
9+
def test_aws_bedrock_converse_parameters(self, mock_boto_client):
10+
"""Test that correct parameters are passed to AWS Bedrock's converse API."""
11+
# Mocking the Bedrock client
12+
mock_bedrock_client = MagicMock()
13+
mock_boto_client.return_value = mock_bedrock_client
14+
15+
# Mock response from the Bedrock API
16+
mock_bedrock_client.converse.return_value = {
17+
"output": {
18+
"message": {
19+
"content": "Test response from AWS Bedrock"
20+
}
21+
}
22+
}
23+
24+
client = Client()
25+
# No need to call client.configure() as boto3 uses environment variables.
26+
27+
# Call the client for AWS Bedrock
28+
bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
29+
messages = [
30+
{"role": "system", "content": "You are a helpful assistant."},
31+
{"role": "user", "content": "Describe cloud computing."}
32+
]
33+
response = client.chat.completions.create(bedrock_model, messages, temperature=0.7, maxTokens=100, top_k=200)
34+
35+
# Assert that the converse API was called with the correct parameters
36+
mock_bedrock_client.converse.assert_called_once_with(
37+
model="claude-v3",
38+
messages=[
39+
{"role": "user", "content": "Describe cloud computing."}
40+
],
41+
system=["You are a helpful assistant."],
42+
inferenceConfig={
43+
'temperature': 0.7,
44+
'maxTokens': 100
45+
},
46+
additionalModelRequestFields={
47+
'top_k': 200
48+
}
49+
)
50+
51+
# Check that the response is normalized correctly
52+
self.assertEqual(response['choices'][0]['message']['content'], "Test response from AWS Bedrock")
53+
54+
if __name__ == '__main__':
55+
unittest.main()

0 commit comments

Comments
 (0)