Skip to content

Commit 9c84264

Browse files
committed
Test multiple providers.
This will be invoked by the GitHub workflow each time a release is cut. This is part of the pre-work for creating release automation.
1 parent 271af0d commit 9c84264

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ build-backend = "poetry.core.masonry.api"
6363

6464
[tool.pytest.ini_options]
6565
testpaths="tests"
66-
addopts=[
67-
"--cov=aisuite",
68-
"--cov-report=term-missing"
66+
markers = [
67+
"integration: marks tests as integration tests that interact with external services",
6968
]

tests/client/test_prerelease.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Run this test before releasing a new version.
2+
# It will test all the models in the client.
3+
4+
import pytest
5+
import aisuite as ai
6+
from typing import List, Dict
7+
from dotenv import load_dotenv, find_dotenv
8+
9+
10+
def setup_client() -> ai.Client:
    """Build an aisuite client after loading provider credentials.

    Loads a local .env file (when one is found) so provider API keys are
    present as environment variables before the client is constructed.
    """
    env_file = find_dotenv()
    load_dotenv(env_file)
    return ai.Client()
14+
15+
16+
def get_test_models() -> List[str]:
    """Return the provider:model identifiers exercised by the pre-release run."""
    model_ids = (
        "anthropic:claude-3-5-sonnet-20240620",
        "aws:meta.llama3-1-8b-instruct-v1:0",
        "huggingface:mistralai/Mistral-7B-Instruct-v0.3",
        "groq:llama3-8b-8192",
        "mistral:open-mistral-7b",
        "openai:gpt-3.5-turbo",
        "cohere:command-r-plus-08-2024",
    )
    return list(model_ids)
27+
28+
29+
def get_test_messages() -> List[Dict[str, str]]:
    """Return the chat transcript sent to every model under test."""
    system_prompt = (
        "Respond in Pirate English. Always try to include the phrase - No rum No fun."
    )
    user_prompt = "Tell me a joke about Captain Jack Sparrow"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
38+
39+
40+
@pytest.mark.integration
@pytest.mark.parametrize("model_id", get_test_models())
def test_model_pirate_response(model_id: str):
    """
    Test that each model responds appropriately to the pirate prompt.

    Args:
        model_id: The provider:model identifier to test.
    """
    client = setup_client()
    messages = get_test_messages()

    # Keep the try block minimal: only the remote call may legitimately
    # raise here. The original wrapped the assertions too, and since
    # AssertionError is a subclass of Exception, a failed assertion was
    # swallowed and re-reported as a generic "Error testing model ..."
    # failure, hiding which check failed and its message.
    try:
        response = client.chat.completions.create(
            model=model_id, messages=messages, temperature=0.75
        )
    except Exception as e:  # provider/transport error, not a test bug
        pytest.fail(f"Error testing model {model_id}: {str(e)}")

    content = response.choices[0].message.content.lower()

    assert isinstance(
        content, str
    ), f"Model {model_id} returned non-string response"
    assert len(content) > 0, f"Model {model_id} returned empty response"

    # Accept both variants of the catchphrase; content was lowercased above.
    assert any(
        phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
    ), f"Model {model_id} did not include required phrase 'No rum No fun'"
71+
72+
73+
# Allow running this file directly (outside a plain `pytest` invocation),
# e.g. as part of a release checklist: `python tests/client/test_prerelease.py`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)