Add the Google/Vertex provider

standsleeping committed Aug 24, 2024
1 parent 75120e9 commit 906078b
Showing 7 changed files with 666 additions and 3 deletions.
2 changes: 2 additions & 0 deletions aimodels/client/client.py
@@ -12,6 +12,7 @@
    OpenAIInterface,
    ReplicateInterface,
    TogetherInterface,
    VertexInterface,
)


@@ -48,6 +49,7 @@ def __init__(self):
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
"vertex": VertexInterface,
}

def get_provider_interface(self, model):
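For context, the `"vertex"` key added above is what lets model strings like `vertex:gemini-1.5-pro-001` (used in the guide below) reach the new interface. The body of `get_provider_interface` is collapsed in this view; purely as an illustration of the prefix convention (the parsing below is an assumption, not the library's actual code), routing a string of that form looks something like this:

```python
# Hypothetical illustration of provider-prefix routing; this is NOT the
# library's actual get_provider_interface implementation.
def route_model_string(model: str):
    """Split 'provider:model-name' into a provider key and a model name."""
    if ":" not in model:
        raise ValueError("Expected 'provider:model-name', e.g. 'vertex:gemini-1.5-pro-001'")
    provider_key, model_name = model.split(":", 1)
    return provider_key, model_name


provider_key, model_name = route_model_string("vertex:gemini-1.5-pro-001")
print(provider_key, model_name)  # vertex gemini-1.5-pro-001
```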
1 change: 1 addition & 0 deletions aimodels/providers/__init__.py
@@ -10,3 +10,4 @@
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .vertex_interface import VertexInterface
77 changes: 77 additions & 0 deletions aimodels/providers/vertex_interface.py
@@ -0,0 +1,77 @@
"""The interface to Vertex AI."""

import os
from aimodels.framework import ProviderInterface, ChatCompletionResponse


class VertexInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Vertex AI."""

def __init__(self):
"""Set up the Vertex AI client with a project ID."""
import vertexai

vertexai.init(
project=os.getenv("VERTEX_PROJECT_ID"), location=os.getenv("VERTEX_REGION")
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Vertex AI API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
Returns:
-------
The ChatCompletionResponse with the completion result.
"""
from vertexai.generative_models import GenerativeModel, GenerationConfig

without_system_messages = self.transform_roles(
messages=messages, from_role="system", to_role="user"
)

with_model_roles = self.transform_roles(
messages=without_system_messages, from_role="assistant", to_role="model"
)

final_message_history = self.convert_openai_to_vertex_ai(with_model_roles[:-1])
last_message = with_model_roles[-1]["content"]

model = GenerativeModel(
model, generation_config=GenerationConfig(temperature=temperature)
)

chat = model.start_chat(history=final_message_history)
response = chat.send_message(last_message)
return self.convert_response_to_openai_format(response)

def convert_openai_to_vertex_ai(self, messages):
"""Convert OpenAI messages to Vertex AI messages."""
from vertexai.generative_models import Content, Part

history = []
for message in messages:
role = message["role"]
content = message["content"]
parts = [Part.from_text(content)]
history.append(Content(role=role, parts=parts))
return history

def transform_roles(self, messages, from_role, to_role):
"""Transform the roles in the messages to the desired role."""
for message in messages:
if message["role"] == from_role:
message["role"] = to_role
return messages

def convert_response_to_openai_format(self, response):
"""Convert Vertex AI response to OpenAI's ChatCompletionResponse format."""
openai_response = ChatCompletionResponse()
openai_response.choices[0].message.content = (
response.candidates[0].content.parts[0].text
)
return openai_response
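For reference, the two `transform_roles` passes above rewrite an OpenAI-style history so that Gemini's two-role scheme (`user`/`model`) is satisfied: `system` messages become `user` messages, and `assistant` messages become `model` messages. A minimal standalone illustration of that mapping (no Vertex AI dependency; the logic is reproduced from the method above):

```python
def transform_roles(messages, from_role, to_role):
    # Same in-place role rewrite as VertexInterface.transform_roles above.
    for message in messages:
        if message["role"] == from_role:
            message["role"] = to_role
    return messages


history = [
    {"role": "system", "content": "Respond in Pirate English."},
    {"role": "user", "content": "Tell me a joke."},
    {"role": "assistant", "content": "Arr, why did the pirate cross the sea?"},
]

history = transform_roles(history, "system", "user")
history = transform_roles(history, "assistant", "model")

print([m["role"] for m in history])  # ['user', 'user', 'model']
```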
92 changes: 92 additions & 0 deletions guides/vertex.md
@@ -0,0 +1,92 @@
# Vertex AI

To use Vertex AI with the `aimodels` library, you'll first need to create a Google Cloud account and set up your environment to work with it.

## Create a Google Cloud Account and Project

Google Cloud provides in-depth [documentation](https://cloud.google.com/vertex-ai/docs/start/cloud-environment) on getting started with their platform, but here are the basic steps:

### Create your account.

Visit [Google Cloud](https://cloud.google.com/free) and follow the instructions for registering a new account. If you already have an account with Google Cloud, sign in and skip to the next step.

### Create a new project and enable billing.

Once you have an account, you can create a new project. Visit the [project selector page](https://console.cloud.google.com/projectselector2/home/dashboard) and click the "New Project" button. Give your project a name and click "Create Project." Your project will be created and you will be redirected to the project dashboard.

Now that you have a project, you'll need to enable billing. Visit the [how-to page](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled#confirm_billing_is_enabled_on_a_project) for billing enablement instructions.

### Set your project ID in an environment variable.

Set the `VERTEX_PROJECT_ID` environment variable to the ID of your project. You can find the Project ID by visiting the project dashboard in the "Project Info" section toward the top of the page.
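If you have the `gcloud` CLI installed (optional; `aimodels` itself does not require it), you can also look up the ID from the terminal:

```shell
# Optional: find the PROJECT_ID column for your project, then export it.
gcloud projects list
export VERTEX_PROJECT_ID="your-project-id"
```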

### Set your preferred region in an environment variable.

Set the `VERTEX_REGION` environment variable to the region where you want your Vertex AI requests to run, such as `us-central1`.
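For example (the region below is just one commonly used option; any region where Vertex AI is available works):

```shell
export VERTEX_REGION="us-central1"
```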

## Create a Service Account For API Access

Because `aimodels` needs to authenticate with Google Cloud to call the Vertex AI API, you'll need to create a service account, download a JSON key file for it from the Google Cloud Console, and set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of that file.

This is documented [here](https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to). The basic steps in the Google Cloud Console are as follows (an optional `gcloud` alternative is sketched after the list):

1. Visit the [service accounts page](https://console.cloud.google.com/iam-admin/serviceaccounts) in the Google Cloud Console.
2. Click the "+ Create Service Account" button toward the top of the page.
3. Follow the steps for naming your service account and granting access to the project.
4. Click "Done" to create the service account.
5. Now, click the "Keys" tab towards the top of the page.
6. Click the "Add Key" menu, then select "Create New Key."
6. Choose "JSON" as the key type, and click "Create."
7. Move this file to a location on your file system like your home directory.
8. Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the JSON file.
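If you prefer the command line and have `gcloud` installed, the same service account and key can be created without the console. The account name and paths below are placeholders, and you may still need to grant the account whatever IAM roles your project requires:

```shell
# Placeholder names; adjust the service account name, project ID, and key path.
gcloud iam service-accounts create aimodels-vertex --project="your-project-id"

gcloud iam service-accounts keys create "$HOME/aimodels-vertex-key.json" \
  --iam-account="aimodels-vertex@your-project-id.iam.gserviceaccount.com"

export GOOGLE_APPLICATION_CREDENTIALS="$HOME/aimodels-vertex-key.json"
```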

## Double-check that your environment is configured correctly.

At this point, you should have three environment variables set:

- `VERTEX_PROJECT_ID`
- `VERTEX_REGION`
- `GOOGLE_APPLICATION_CREDENTIALS`

Once these are set, you are ready to write some code and send a chat completion request.
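Before making a request, a quick sanity check (a small standalone sketch, not part of the `aimodels` API) can confirm that all three variables are visible to Python:

```python
import os

# Fail fast if any of the three required variables is missing.
for var in ("VERTEX_PROJECT_ID", "VERTEX_REGION", "GOOGLE_APPLICATION_CREDENTIALS"):
    if not os.getenv(var):
        raise RuntimeError(f"{var} is not set")
    print(f"{var} is set")

# The credentials variable should point at the JSON key file you downloaded.
assert os.path.isfile(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]), "key file not found"
```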

## Create a chat completion.

With your account and service account set up, you can send a chat completion request.

Export the environment variables:

```shell
export VERTEX_PROJECT_ID="your-project-id"
export VERTEX_REGION="your-region"
export GOOGLE_APPLICATION_CREDENTIALS="path/to/your/service-account-file.json"
```

Install the Vertex AI SDK:

```shell
pip install vertexai
```

In your code:

```python
import aimodels as ai

client = ai.Client()

model = "vertex:gemini-1.5-pro-001"

messages = [
{"role": "system", "content": "Respond in Pirate English."},
{"role": "user", "content": "Tell me a joke."},
]

response = client.chat.completions.create(
model=model,
messages=messages,
)

print(response.choices[0].message.content)
```
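Because the Vertex interface replays everything except the most recent message as chat history (with `assistant` messages mapped to Gemini's `model` role), you can keep the conversation going by appending the reply and calling the client again:

```python
# Continue the same conversation: append the reply, then ask a follow-up.
messages.append({"role": "assistant", "content": response.choices[0].message.content})
messages.append({"role": "user", "content": "Tell me another one."})

follow_up = client.chat.completions.create(
    model=model,
    messages=messages,
)

print(follow_up.choices[0].message.content)
```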

Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).