diff --git a/.gitignore b/.gitignore index e1084c98..e583ea06 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ .vscode/ __pycache__/ env/ +.env diff --git a/aisuite/provider.py b/aisuite/provider.py index c28e096a..46e8f493 100644 --- a/aisuite/provider.py +++ b/aisuite/provider.py @@ -22,6 +22,7 @@ class ProviderNames(str, Enum): AWS_BEDROCK = "aws-bedrock" ANTHROPIC = "anthropic" AZURE = "azure" + GROQ = "groq" class ProviderFactory: @@ -38,6 +39,7 @@ class ProviderFactory: "AnthropicProvider", ), ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"), + ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"), } @classmethod diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py index 6ddde342..752ea96d 100644 --- a/aisuite/providers/groq_provider.py +++ b/aisuite/providers/groq_provider.py @@ -1,9 +1,26 @@ +import os + +import groq from aisuite.provider import Provider class GroqProvider(Provider): - def __init__(self) -> None: - pass + def __init__(self, **config): + """ + Initialize the Groq provider with the given configuration. + Pass the entire configuration dictionary to the Groq client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("GROQ_API_KEY")) + if not config["api_key"]: + raise ValueError( + " API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." + ) + self.client = groq.Groq(**config) - def chat_completions_create(self, model, messages): - raise ValueError("Groq provider not yet implemented.") + def chat_completions_create(self, model, messages, **kwargs): + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Groq API + ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 7884a7fe..151fc719 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -6,6 +6,7 @@ class TestClient(unittest.TestCase): + @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create") @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") @patch( "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create" @@ -15,13 +16,14 @@ class TestClient(unittest.TestCase): "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create" ) def test_client_chat_completions( - self, mock_anthropic, mock_azure, mock_bedrock, mock_openai + self, mock_anthropic, mock_azure, mock_bedrock, mock_openai, mock_groq ): # Mock responses from providers mock_openai.return_value = "OpenAI Response" mock_bedrock.return_value = "AWS Bedrock Response" mock_azure.return_value = "Azure Response" mock_anthropic.return_value = "Anthropic Response" + mock_groq.return_value = "Groq Response" # Provider configurations provider_configs = { @@ -35,6 +37,9 @@ def test_client_chat_completions( ProviderNames.AZURE: { "api_key": "azure-api-key", }, + ProviderNames.GROQ: { + "api_key": "groq-api-key", + }, } # Initialize the client @@ -61,11 +66,13 @@ def test_client_chat_completions( self.assertEqual(bedrock_response, "AWS Bedrock Response") mock_bedrock.assert_called_once() + # Test Azure model azure_model = ProviderNames.AZURE + ":" + "azure-model" azure_response = client.chat.completions.create(azure_model, messages=messages) self.assertEqual(azure_response, "Azure Response") mock_azure.assert_called_once() + # Test Anthropic model anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model" anthropic_response = client.chat.completions.create( anthropic_model, messages=messages @@ -73,6 +80,12 @@ def test_client_chat_completions( self.assertEqual(anthropic_response, "Anthropic Response") mock_anthropic.assert_called_once() + # Test Groq model + groq_model = ProviderNames.GROQ + ":" + "groq-model" + groq_response = client.chat.completions.create(groq_model, messages=messages) + self.assertEqual(openai_response, "OpenAI Response") + mock_groq.assert_called_once() + # Test that new instances of Completion are not created each time we make an inference call. compl_instance = client.chat.completions next_compl_instance = client.chat.completions diff --git a/tests/providers/test_groq_provider.py b/tests/providers/test_groq_provider.py new file mode 100644 index 00000000..94e953fc --- /dev/null +++ b/tests/providers/test_groq_provider.py @@ -0,0 +1,46 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from aisuite.providers.groq_provider import GroqProvider + + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("GROQ_API_KEY", "test-api-key") + + +def test_groq_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" + + provider = GroqProvider() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = response_text_content + + with patch.object( + provider.client.chat.completions, + "create", + return_value=mock_response, + ) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_create.assert_called_with( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + assert response.choices[0].message.content == response_text_content