Skip to content

Commit

Permalink
Update method
Browse files Browse the repository at this point in the history
  • Loading branch information
standsleeping committed Aug 27, 2024
1 parent 5aa9b06 commit d7448a3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
23 changes: 10 additions & 13 deletions aimodels/providers/google_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ def chat_completion_create(self, messages=None, model=None, temperature=0):
"""
from vertexai.generative_models import GenerativeModel, GenerationConfig

transformed_messages = self.transform_roles(
messages=messages,
transformations=[("system", "user"), ("assistant", "model")],
)
transformed_messages = self.transform_roles(messages)

final_message_history = self.convert_openai_to_vertex_ai(
transformed_messages[:-1]
Expand All @@ -69,17 +66,17 @@ def convert_openai_to_vertex_ai(self, messages):
history.append(Content(role=role, parts=parts))
return history

def transform_roles(self, messages, transformations):
def transform_roles(self, messages):
"""Transform the roles in the messages based on the provided transformations."""
transformed_messages = []
openai_roles_to_google_roles = {
"system": "user",
"assistant": "model",
}

for message in messages:
new_message = message.copy()
for from_role, to_role in transformations:
if new_message["role"] == from_role:
new_message["role"] = to_role
break
transformed_messages.append(new_message)
return transformed_messages
if role := openai_roles_to_google_roles.get(message["role"], None):
message["role"] = role
return messages

def convert_response_to_openai_format(self, response):
"""Convert Google AI response to OpenAI's ChatCompletionResponse format."""
Expand Down
4 changes: 1 addition & 3 deletions tests/providers/test_google_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def test_transform_roles():
{"role": "model", "content": "Assistant message 1."},
]

result = interface.transform_roles(
messages, transformations=[("system", "user"), ("assistant", "model")]
)
result = interface.transform_roles(messages)

assert result == expected_output

0 comments on commit d7448a3

Please sign in to comment.