Commit 120ed16

Merge pull request #29 from andrewyng/update-google-provider
Update GoogleProvider
2 parents 8cda46c + dab5be0

5 files changed: +59 −22 lines

aisuite/provider.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ class ProviderNames(str, Enum):
     AWS_BEDROCK = "aws-bedrock"
     AZURE = "azure"
     GROQ = "groq"
+    GOOGLE = "google"
     MISTRAL = "mistral"
     OPENAI = "openai"
 
@@ -40,6 +41,7 @@ class ProviderFactory:
         ),
         ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
         ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
+        ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"),
         ProviderNames.MISTRAL: (
             "aisuite.providers.mistral_provider",
             "MistralProvider",

aisuite/providers/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -5,4 +5,3 @@
 from .ollama_interface import OllamaInterface
 from .replicate_interface import ReplicateInterface
 from .together_interface import TogetherInterface
-from .google_interface import GoogleInterface

aisuite/providers/google_interface.py renamed to aisuite/providers/google_provider.py

Lines changed: 32 additions & 14 deletions
@@ -1,58 +1,76 @@
 """The interface to Google's Vertex AI."""
 
 import os
+
+import vertexai
+from vertexai.generative_models import GenerativeModel, GenerationConfig
+
 from aisuite.framework import ProviderInterface, ChatCompletionResponse
 
 
-class GoogleInterface(ProviderInterface):
+DEFAULT_TEMPERATURE = 0.7
+
+
+class GoogleProvider(ProviderInterface):
     """Implements the ProviderInterface for interacting with Google's Vertex AI."""
 
-    def __init__(self):
+    def __init__(self, **config):
         """Set up the Google AI client with a project ID."""
-        import vertexai
-
-        project_id = os.getenv("GOOGLE_PROJECT_ID")
-        location = os.getenv("GOOGLE_REGION")
-        app_creds_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+        self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID")
+        self.location = config.get("region") or os.getenv("GOOGLE_REGION")
+        self.app_creds_path = config.get("application_credentials") or os.getenv(
+            "GOOGLE_APPLICATION_CREDENTIALS"
+        )
 
-        if not project_id or not location or not app_creds_path:
+        if not self.project_id or not self.location or not self.app_creds_path:
             raise EnvironmentError(
                 "Missing one or more required Google environment variables: "
                 "GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. "
                 "Please refer to the setup guide: /guides/google.md."
             )
 
-        vertexai.init(project=project_id, location=location)
+        vertexai.init(project=self.project_id, location=self.location)
 
-    def chat_completion_create(self, messages=None, model=None, temperature=0):
+    def chat_completions_create(self, model, messages, **kwargs):
         """Request chat completions from the Google AI API.
 
         Args:
         ----
             model (str): Identifies the specific provider/model to use.
             messages (list of dict): A list of message objects in chat history.
+            kwargs (dict): Optional arguments for the Google AI API.
 
         Returns:
         -------
             The ChatCompletionResponse with the completion result.
 
         """
-        from vertexai.generative_models import GenerativeModel, GenerationConfig
 
+        # Set the temperature if provided, otherwise use the default
+        temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE)
+
+        # Transform the roles in the messages
         transformed_messages = self.transform_roles(messages)
 
+        # Convert the messages to the format expected Google
         final_message_history = self.convert_openai_to_vertex_ai(
             transformed_messages[:-1]
         )
+
+        # Get the last message from the transformed messages
         last_message = transformed_messages[-1]["content"]
 
+        # Create the GenerativeModel with the specified model and generation configuration
         model = GenerativeModel(
            model, generation_config=GenerationConfig(temperature=temperature)
         )
 
+        # Start a chat with the GenerativeModel and send the last message
         chat = model.start_chat(history=final_message_history)
         response = chat.send_message(last_message)
-        return self.convert_response_to_openai_format(response)
+
+        # Convert the response to the format expected by the OpenAI API
+        return self.normalize_response(response)
 
     def convert_openai_to_vertex_ai(self, messages):
         """Convert OpenAI messages to Google AI messages."""

@@ -78,8 +96,8 @@ def transform_roles(self, messages):
             message["role"] = role
         return messages
 
-    def convert_response_to_openai_format(self, response):
-        """Convert Google AI response to OpenAI's ChatCompletionResponse format."""
+    def normalize_response(self, response):
+        """Normalize the response from Google AI to match OpenAI's response format."""
         openai_response = ChatCompletionResponse()
         openai_response.choices[0].message.content = (
             response.candidates[0].content.parts[0].text
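
For reference, a minimal sketch of driving the renamed provider directly; the constructor keys and the chat_completions_create signature come from the diff above, while the project, region, credentials path, and model name are placeholders, and a real call needs working Vertex AI credentials:

from aisuite.providers.google_provider import GoogleProvider

# Each value is a placeholder; any key left out falls back to the matching
# GOOGLE_PROJECT_ID / GOOGLE_REGION / GOOGLE_APPLICATION_CREDENTIALS env var.
provider = GoogleProvider(
    project_id="my-gcp-project",
    region="us-west4",
    application_credentials="/path/to/service-account.json",
)

response = provider.chat_completions_create(
    model="gemini-1.5-pro",  # placeholder model name
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    temperature=0.2,  # optional; DEFAULT_TEMPERATURE (0.7) is used when omitted
)
print(response.choices[0].message.content)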

tests/client/test_client.py

Lines changed: 16 additions & 0 deletions
@@ -15,8 +15,10 @@ class TestClient(unittest.TestCase):
     @patch(
         "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
     )
+    @patch("aisuite.providers.google_provider.GoogleProvider.chat_completions_create")
     def test_client_chat_completions(
         self,
+        mock_google,
         mock_anthropic,
         mock_azure,
         mock_bedrock,
@@ -31,6 +33,7 @@ def test_client_chat_completions(
         mock_anthropic.return_value = "Anthropic Response"
         mock_groq.return_value = "Groq Response"
         mock_mistral.return_value = "Mistral Response"
+        mock_google.return_value = "Google Response"
 
         # Provider configurations
         provider_configs = {
@@ -50,6 +53,11 @@ def test_client_chat_completions(
             ProviderNames.MISTRAL: {
                 "api_key": "mistral-api-key",
             },
+            ProviderNames.GOOGLE: {
+                "project_id": "test_google_project_id",
+                "region": "us-west4",
+                "application_credentials": "test_google_application_credentials",
+            },
         }
 
         # Initialize the client
@@ -104,6 +112,14 @@ def test_client_chat_completions(
         self.assertEqual(mistral_response, "Mistral Response")
         mock_mistral.assert_called_once()
 
+        # Test Google model
+        google_model = ProviderNames.GOOGLE + ":" + "google-model"
+        google_response = client.chat.completions.create(
+            google_model, messages=messages
+        )
+        self.assertEqual(google_response, "Google Response")
+        mock_google.assert_called_once()
+
         # Test that new instances of Completion are not created each time we make an inference call.
         compl_instance = client.chat.completions
         next_compl_instance = client.chat.completions

tests/providers/test_google_interface.py renamed to tests/providers/test_google_provider.py

Lines changed: 9 additions & 7 deletions
@@ -1,6 +1,6 @@
 import pytest
 from unittest.mock import patch, MagicMock
-from aisuite.providers.google_interface import GoogleInterface
+from aisuite.providers.google_provider import GoogleProvider
 from vertexai.generative_models import Content, Part
 
 
@@ -16,7 +16,7 @@ def test_missing_env_vars():
     """Test that an error is raised if required environment variables are missing."""
     with patch.dict("os.environ", {}, clear=True):
         with pytest.raises(EnvironmentError) as exc_info:
-            GoogleInterface()
+            GoogleProvider()
         assert "Missing one or more required Google environment variables" in str(
             exc_info.value
         )
@@ -30,19 +30,21 @@ def test_vertex_interface():
     selected_model = "our-favorite-model"
     response_text_content = "mocked-text-response-from-model"
 
-    interface = GoogleInterface()
+    interface = GoogleProvider()
     mock_response = MagicMock()
     mock_response.candidates = [MagicMock()]
     mock_response.candidates[0].content.parts[0].text = response_text_content
 
-    with patch("vertexai.generative_models.GenerativeModel") as mock_generative_model:
+    with patch(
+        "aisuite.providers.google_provider.GenerativeModel"
+    ) as mock_generative_model:
         mock_model = MagicMock()
         mock_generative_model.return_value = mock_model
         mock_chat = MagicMock()
         mock_model.start_chat.return_value = mock_chat
         mock_chat.send_message.return_value = mock_response
 
-        response = interface.chat_completion_create(
+        response = interface.chat_completions_create(
             messages=message_history,
             model=selected_model,
             temperature=0.7,
@@ -68,7 +70,7 @@ def test_vertex_interface():
 
 
 def test_convert_openai_to_vertex_ai():
-    interface = GoogleInterface()
+    interface = GoogleProvider()
     message_history = [{"role": "user", "content": "Hello!"}]
     result = interface.convert_openai_to_vertex_ai(message_history)
     assert isinstance(result[0], Content)
@@ -79,7 +81,7 @@ def test_convert_openai_to_vertex_ai():
 
 
 def test_transform_roles():
-    interface = GoogleInterface()
+    interface = GoogleProvider()
 
     messages = [
         {"role": "system", "content": "Google: system message = 1st user message."},
