Skip to content

Commit 2f89367

Browse files
Merge pull request #6 from andrewyng/start-testing
Start testing
2 parents e91e5fc + 1ef852e commit 2f89367

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

.github/workflows/run_pytest.yml

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: Lint
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
build_and_test:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
python-version: [ "3.10", "3.11", "3.12" ]
11+
steps:
12+
- uses: actions/checkout@v4
13+
- name: Set up Python ${{ matrix.python-version }}
14+
uses: actions/setup-python@v5
15+
with:
16+
python-version: ${{ matrix.python-version }}
17+
- name: Install dependencies
18+
run: |
19+
python -m pip install --upgrade pip
20+
pip install poetry
21+
poetry install
22+
- name: Test with pytest
23+
run: poetry run pytest
24+

tests/test_anthropic_interface.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock
3+
from aimodels.providers.anthropic_interface import AnthropicInterface
4+
5+
6+
@pytest.fixture(autouse=True)
7+
def set_api_key_env_var(monkeypatch):
8+
"""Fixture to set environment variables for tests."""
9+
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
10+
11+
12+
def test_anthropic_interface():
13+
"""High-level test that the interface is initialized and chat completions are requested successfully."""
14+
15+
user_greeting = "Hello!"
16+
message_history = [{"role": "user", "content": user_greeting}]
17+
selected_model = "our-favorite-model"
18+
chosen_temperature = 0.75
19+
response_text_content = "mocked-text-response-from-model"
20+
21+
interface = AnthropicInterface()
22+
mock_response = MagicMock()
23+
mock_response.content = [MagicMock()]
24+
mock_response.content[0].text = response_text_content
25+
26+
with patch.object(
27+
interface.anthropic_client.messages, "create", return_value=mock_response
28+
) as mock_create:
29+
response = interface.chat_completion_create(
30+
messages=message_history,
31+
model=selected_model,
32+
temperature=chosen_temperature,
33+
)
34+
35+
transformed_message_history = [
36+
{"role": "user", "content": [{"type": "text", "text": user_greeting}]},
37+
]
38+
39+
mock_create.assert_called_with(
40+
messages=transformed_message_history,
41+
model=selected_model,
42+
temperature=chosen_temperature,
43+
max_tokens=4096,
44+
)
45+
46+
assert response.choices[0].message.content == response_text_content

tests/test_multi_fm_client.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
from aimodels.client.multi_fm_client import MultiFMClient, AnthropicInterface
3+
4+
5+
def test_get_provider_interface_with_new_instance():
6+
"""Test that get_provider_interface creates a new instance of the interface."""
7+
client = MultiFMClient()
8+
interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
9+
assert isinstance(interface, AnthropicInterface)
10+
assert model_name == "some-model:v1"
11+
assert client.all_interfaces["anthropic"] == interface
12+
13+
14+
def test_get_provider_interface_with_existing_instance():
15+
"""Test that get_provider_interface returns an existing instance of the interface, if already created."""
16+
client = MultiFMClient()
17+
18+
# New interface instance
19+
new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
20+
21+
# Call twice, get same instance back
22+
same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
23+
24+
assert new_instance is same_instance
25+
26+
27+
def test_get_provider_interface_with_invalid_format():
28+
client = MultiFMClient()
29+
30+
with pytest.raises(ValueError) as exc_info:
31+
client.get_provider_interface("invalid-model-no-colon")
32+
33+
assert "Expected ':' in model identifier" in str(exc_info.value)
34+
35+
36+
def test_get_provider_interface_with_unknown_interface():
37+
client = MultiFMClient()
38+
39+
with pytest.raises(Exception) as exc_info:
40+
client.get_provider_interface("unknown-interface:some-model")
41+
42+
assert (
43+
"Could not find factory to create interface for provider 'unknown-interface'"
44+
in str(exc_info.value)
45+
)

0 commit comments

Comments
 (0)