Skip to content

Commit 8bceb35

Browse files
committed
add the groq provider
1 parent d77f312 commit 8bceb35

File tree

5 files changed

+84
-5
lines changed

5 files changed

+84
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
.vscode/
33
__pycache__/
44
env/
5+
.env

aisuite/provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class ProviderNames(str, Enum):
2222
AWS_BEDROCK = "aws-bedrock"
2323
ANTHROPIC = "anthropic"
2424
AZURE = "azure"
25+
GROQ = "groq"
2526

2627

2728
class ProviderFactory:
@@ -38,6 +39,7 @@ class ProviderFactory:
3839
"AnthropicProvider",
3940
),
4041
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
42+
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
4143
}
4244

4345
@classmethod

aisuite/providers/groq_provider.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
import os
2+
3+
import groq
14
from aisuite.provider import Provider
25

36

47
class GroqProvider(Provider):
5-
def __init__(self) -> None:
6-
pass
8+
def __init__(self, **config):
9+
"""
10+
Initialize the Groq provider with the given configuration.
11+
Pass the entire configuration dictionary to the Groq client constructor.
12+
"""
13+
# Ensure API key is provided either in config or via environment variable
14+
config.setdefault("api_key", os.getenv("GROQ_API_KEY"))
15+
if not config["api_key"]:
16+
raise ValueError(
17+
" API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
18+
)
19+
self.client = groq.Groq(**config)
720

8-
def chat_completions_create(self, model, messages):
9-
raise ValueError("Groq provider not yet implemented.")
21+
def chat_completions_create(self, model, messages, **kwargs):
22+
return self.client.chat.completions.create(
23+
model=model,
24+
messages=messages,
25+
**kwargs # Pass any additional arguments to the Groq API
26+
)

tests/client/test_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
class TestClient(unittest.TestCase):
88

9+
@patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
910
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
1011
@patch(
1112
"aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
@@ -15,13 +16,14 @@ class TestClient(unittest.TestCase):
1516
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
1617
)
1718
def test_client_chat_completions(
18-
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
19+
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai, mock_groq
1920
):
2021
# Mock responses from providers
2122
mock_openai.return_value = "OpenAI Response"
2223
mock_bedrock.return_value = "AWS Bedrock Response"
2324
mock_azure.return_value = "Azure Response"
2425
mock_anthropic.return_value = "Anthropic Response"
26+
mock_groq.return_value = "Groq Response"
2527

2628
# Provider configurations
2729
provider_configs = {
@@ -35,6 +37,9 @@ def test_client_chat_completions(
3537
ProviderNames.AZURE: {
3638
"api_key": "azure-api-key",
3739
},
40+
ProviderNames.GROQ: {
41+
"api_key": "groq-api-key",
42+
},
3843
}
3944

4045
# Initialize the client
@@ -61,18 +66,26 @@ def test_client_chat_completions(
6166
self.assertEqual(bedrock_response, "AWS Bedrock Response")
6267
mock_bedrock.assert_called_once()
6368

69+
# Test Azure model
6470
azure_model = ProviderNames.AZURE + ":" + "azure-model"
6571
azure_response = client.chat.completions.create(azure_model, messages=messages)
6672
self.assertEqual(azure_response, "Azure Response")
6773
mock_azure.assert_called_once()
6874

75+
# Test Anthropic model
6976
anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
7077
anthropic_response = client.chat.completions.create(
7178
anthropic_model, messages=messages
7279
)
7380
self.assertEqual(anthropic_response, "Anthropic Response")
7481
mock_anthropic.assert_called_once()
7582

83+
# Test Groq model
84+
groq_model = ProviderNames.GROQ + ":" + "groq-model"
85+
groq_response = client.chat.completions.create(groq_model, messages=messages)
86+
self.assertEqual(openai_response, "OpenAI Response")
87+
mock_groq.assert_called_once()
88+
7689
# Test that new instances of Completion are not created each time we make an inference call.
7790
compl_instance = client.chat.completions
7891
next_compl_instance = client.chat.completions

tests/providers/test_groq_provider.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
5+
from aisuite.providers.groq_provider import GroqProvider
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def set_api_key_env_var(monkeypatch):
10+
"""Fixture to set environment variables for tests."""
11+
monkeypatch.setenv("GROQ_API_KEY", "test-api-key")
12+
13+
14+
def test_groq_provider():
15+
"""High-level test that the provider is initialized and chat completions are requested successfully."""
16+
17+
user_greeting = "Hello!"
18+
message_history = [{"role": "user", "content": user_greeting}]
19+
selected_model = "our-favorite-model"
20+
chosen_temperature = 0.75
21+
response_text_content = "mocked-text-response-from-model"
22+
23+
provider = GroqProvider()
24+
mock_response = MagicMock()
25+
mock_response.choices = [MagicMock()]
26+
mock_response.choices[0].message = MagicMock()
27+
mock_response.choices[0].message.content = response_text_content
28+
29+
with patch.object(
30+
provider.client.chat.completions,
31+
"create",
32+
return_value=mock_response,
33+
) as mock_create:
34+
response = provider.chat_completions_create(
35+
messages=message_history,
36+
model=selected_model,
37+
temperature=chosen_temperature,
38+
)
39+
40+
mock_create.assert_called_with(
41+
messages=message_history,
42+
model=selected_model,
43+
temperature=chosen_temperature,
44+
)
45+
46+
assert response.choices[0].message.content == response_text_content

0 commit comments

Comments
 (0)