Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
from .message import Message
26 changes: 0 additions & 26 deletions aisuite/framework/provider_interface.py

This file was deleted.

12 changes: 8 additions & 4 deletions aisuite/providers/google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
)
import pprint

from aisuite.framework import ProviderInterface, ChatCompletionResponse, Message

from aisuite.framework import ChatCompletionResponse, Message
from aisuite.provider import Provider

DEFAULT_TEMPERATURE = 0.7
ENABLE_DEBUG_MESSAGES = False
Expand Down Expand Up @@ -189,8 +189,8 @@ def convert_response(response) -> ChatCompletionResponse:
return openai_response


class GoogleProvider(ProviderInterface):
"""Implements the ProviderInterface for interacting with Google's Vertex AI."""
class GoogleProvider(Provider):
"""Implements the Provider Interface for interacting with Google's Vertex AI."""

def __init__(self, **config):
"""Set up the Google AI client with a project ID."""
Expand Down Expand Up @@ -229,6 +229,9 @@ def chat_completions_create(self, model, messages, **kwargs):
# Set the temperature if provided, otherwise use the default
temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE)

# Set safety_settings if provided
safety_settings = kwargs.get("safety_settings")

# Convert messages to Vertex AI format
message_history = self.transformer.convert_request(messages)

Expand Down Expand Up @@ -274,6 +277,7 @@ def chat_completions_create(self, model, messages, **kwargs):
model,
generation_config=GenerationConfig(temperature=temperature),
tools=tools,
safety_settings=safety_settings
)

if ENABLE_DEBUG_MESSAGES:
Expand Down
41 changes: 41 additions & 0 deletions guides/google.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,45 @@ response = client.chat.completions.create(
print(response.choices[0].message.content)
```

## Safety Settings

```python
from aisuite import Client

client = Client({
"google":{
"project_id": "project-id",
"region": "us-central1",
}
})

model = "google:gemini-2.0-flash-001"

messages = [{
"role": "user",
"content": "I shouldn't use a public swimming pool"}]

from vertexai.generative_models import (
HarmCategory,
HarmBlockThreshold,
SafetySetting,
)

safety_config = [
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=HarmBlockThreshold.BLOCK_NONE,
),
]

response = client.chat.completions.create( safety_settings=safety_config,
model=model, messages=messages)
print(response.choices[0].message.content)

```

Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).