1
1
"""The interface to Google's Vertex AI."""
2
2
3
3
import os
4
+
5
+ import vertexai
6
+ from vertexai .generative_models import GenerativeModel , GenerationConfig
7
+
4
8
from aisuite .framework import ProviderInterface , ChatCompletionResponse
5
9
6
10
7
- class GoogleInterface (ProviderInterface ):
11
+ DEFAULT_TEMPERATURE = 0.7
12
+
13
+
14
+ class GoogleProvider (ProviderInterface ):
8
15
"""Implements the ProviderInterface for interacting with Google's Vertex AI."""
9
16
10
- def __init__ (self ):
17
+ def __init__ (self , ** config ):
11
18
"""Set up the Google AI client with a project ID."""
12
- import vertexai
13
-
14
- project_id = os . getenv ( "GOOGLE_PROJECT_ID" )
15
- location = os . getenv ( "GOOGLE_REGION" )
16
- app_creds_path = os . getenv ( "GOOGLE_APPLICATION_CREDENTIALS" )
19
+ self . project_id = config . get ( "project_id" ) or os . getenv ( "GOOGLE_PROJECT_ID" )
20
+ self . location = config . get ( "region" ) or os . getenv ( "GOOGLE_REGION" )
21
+ self . app_creds_path = config . get ( "application_credentials" ) or os . getenv (
22
+ "GOOGLE_APPLICATION_CREDENTIALS"
23
+ )
17
24
18
- if not project_id or not location or not app_creds_path :
25
+ if not self . project_id or not self . location or not self . app_creds_path :
19
26
raise EnvironmentError (
20
27
"Missing one or more required Google environment variables: "
21
28
"GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. "
22
29
"Please refer to the setup guide: /guides/google.md."
23
30
)
24
31
25
- vertexai .init (project = project_id , location = location )
32
+ vertexai .init (project = self . project_id , location = self . location )
26
33
27
- def chat_completion_create (self , messages = None , model = None , temperature = 0 ):
34
+ def chat_completions_create (self , model , messages , ** kwargs ):
28
35
"""Request chat completions from the Google AI API.
29
36
30
37
Args:
31
38
----
32
39
model (str): Identifies the specific provider/model to use.
33
40
messages (list of dict): A list of message objects in chat history.
41
+ kwargs (dict): Optional arguments for the Google AI API.
34
42
35
43
Returns:
36
44
-------
37
45
The ChatCompletionResponse with the completion result.
38
46
39
47
"""
40
- from vertexai .generative_models import GenerativeModel , GenerationConfig
41
48
49
+ # Set the temperature if provided, otherwise use the default
50
+ temperature = kwargs .get ("temperature" , DEFAULT_TEMPERATURE )
51
+
52
+ # Transform the roles in the messages
42
53
transformed_messages = self .transform_roles (messages )
43
54
55
+ # Convert the messages to the format expected Google
44
56
final_message_history = self .convert_openai_to_vertex_ai (
45
57
transformed_messages [:- 1 ]
46
58
)
59
+
60
+ # Get the last message from the transformed messages
47
61
last_message = transformed_messages [- 1 ]["content" ]
48
62
63
+ # Create the GenerativeModel with the specified model and generation configuration
49
64
model = GenerativeModel (
50
65
model , generation_config = GenerationConfig (temperature = temperature )
51
66
)
52
67
68
+ # Start a chat with the GenerativeModel and send the last message
53
69
chat = model .start_chat (history = final_message_history )
54
70
response = chat .send_message (last_message )
55
- return self .convert_response_to_openai_format (response )
71
+
72
+ # Convert the response to the format expected by the OpenAI API
73
+ return self .normalize_response (response )
56
74
57
75
def convert_openai_to_vertex_ai (self , messages ):
58
76
"""Convert OpenAI messages to Google AI messages."""
@@ -78,8 +96,8 @@ def transform_roles(self, messages):
78
96
message ["role" ] = role
79
97
return messages
80
98
81
- def convert_response_to_openai_format (self , response ):
82
- """Convert Google AI response to OpenAI's ChatCompletionResponse format."""
99
+ def normalize_response (self , response ):
100
+ """Normalize the response from Google AI to match OpenAI's response format."""
83
101
openai_response = ChatCompletionResponse ()
84
102
openai_response .choices [0 ].message .content = (
85
103
response .candidates [0 ].content .parts [0 ].text
0 commit comments