Skip to content

Commit

Permalink
Merge pull request #28 from andrewyng/migrate-groq-to-new-format
Browse files Browse the repository at this point in the history
  • Loading branch information
ksolo authored Sep 13, 2024
2 parents d77f312 + 8bceb35 commit 6247876
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
.vscode/
__pycache__/
env/
.env
2 changes: 2 additions & 0 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ProviderNames(str, Enum):
AWS_BEDROCK = "aws-bedrock"
ANTHROPIC = "anthropic"
AZURE = "azure"
GROQ = "groq"


class ProviderFactory:
Expand All @@ -38,6 +39,7 @@ class ProviderFactory:
"AnthropicProvider",
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
}

@classmethod
Expand Down
25 changes: 21 additions & 4 deletions aisuite/providers/groq_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
import os

import groq
from aisuite.provider import Provider


class GroqProvider(Provider):
def __init__(self) -> None:
pass
def __init__(self, **config):
"""
Initialize the Groq provider with the given configuration.
Pass the entire configuration dictionary to the Groq client constructor.
"""
# Ensure API key is provided either in config or via environment variable
config.setdefault("api_key", os.getenv("GROQ_API_KEY"))
if not config["api_key"]:
raise ValueError(
" API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
)
self.client = groq.Groq(**config)

def chat_completions_create(self, model, messages):
raise ValueError("Groq provider not yet implemented.")
def chat_completions_create(self, model, messages, **kwargs):
return self.client.chat.completions.create(
model=model,
messages=messages,
**kwargs # Pass any additional arguments to the Groq API
)
15 changes: 14 additions & 1 deletion tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

class TestClient(unittest.TestCase):

@patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
@patch(
"aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
Expand All @@ -15,13 +16,14 @@ class TestClient(unittest.TestCase):
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
)
def test_client_chat_completions(
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
self, mock_anthropic, mock_azure, mock_bedrock, mock_openai, mock_groq
):
# Mock responses from providers
mock_openai.return_value = "OpenAI Response"
mock_bedrock.return_value = "AWS Bedrock Response"
mock_azure.return_value = "Azure Response"
mock_anthropic.return_value = "Anthropic Response"
mock_groq.return_value = "Groq Response"

# Provider configurations
provider_configs = {
Expand All @@ -35,6 +37,9 @@ def test_client_chat_completions(
ProviderNames.AZURE: {
"api_key": "azure-api-key",
},
ProviderNames.GROQ: {
"api_key": "groq-api-key",
},
}

# Initialize the client
Expand All @@ -61,18 +66,26 @@ def test_client_chat_completions(
self.assertEqual(bedrock_response, "AWS Bedrock Response")
mock_bedrock.assert_called_once()

# Test Azure model
azure_model = ProviderNames.AZURE + ":" + "azure-model"
azure_response = client.chat.completions.create(azure_model, messages=messages)
self.assertEqual(azure_response, "Azure Response")
mock_azure.assert_called_once()

# Test Anthropic model
anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
anthropic_response = client.chat.completions.create(
anthropic_model, messages=messages
)
self.assertEqual(anthropic_response, "Anthropic Response")
mock_anthropic.assert_called_once()

# Test Groq model
groq_model = ProviderNames.GROQ + ":" + "groq-model"
groq_response = client.chat.completions.create(groq_model, messages=messages)
self.assertEqual(openai_response, "OpenAI Response")
mock_groq.assert_called_once()

# Test that new instances of Completion are not created each time we make an inference call.
compl_instance = client.chat.completions
next_compl_instance = client.chat.completions
Expand Down
46 changes: 46 additions & 0 deletions tests/providers/test_groq_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.providers.groq_provider import GroqProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("GROQ_API_KEY", "test-api-key")


def test_groq_provider():
"""High-level test that the provider is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = GroqProvider()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = response_text_content

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content

0 comments on commit 6247876

Please sign in to comment.