diff --git a/aimodels/providers/google_interface.py b/aimodels/providers/google_interface.py index bfe13623..9e4fd2ab 100644 --- a/aimodels/providers/google_interface.py +++ b/aimodels/providers/google_interface.py @@ -39,10 +39,7 @@ def chat_completion_create(self, messages=None, model=None, temperature=0): """ from vertexai.generative_models import GenerativeModel, GenerationConfig - transformed_messages = self.transform_roles( - messages=messages, - transformations=[("system", "user"), ("assistant", "model")], - ) + transformed_messages = self.transform_roles(messages) final_message_history = self.convert_openai_to_vertex_ai( transformed_messages[:-1] @@ -69,17 +66,17 @@ def convert_openai_to_vertex_ai(self, messages): history.append(Content(role=role, parts=parts)) return history - def transform_roles(self, messages, transformations): + def transform_roles(self, messages): """Transform the roles in the messages based on the provided transformations.""" - transformed_messages = [] + openai_roles_to_google_roles = { + "system": "user", + "assistant": "model", + } + for message in messages: - new_message = message.copy() - for from_role, to_role in transformations: - if new_message["role"] == from_role: - new_message["role"] = to_role - break - transformed_messages.append(new_message) - return transformed_messages + if role := openai_roles_to_google_roles.get(message["role"], None): + message["role"] = role + return messages def convert_response_to_openai_format(self, response): """Convert Google AI response to OpenAI's ChatCompletionResponse format.""" diff --git a/tests/providers/test_google_interface.py b/tests/providers/test_google_interface.py index bc91fdba..ee3871f7 100644 --- a/tests/providers/test_google_interface.py +++ b/tests/providers/test_google_interface.py @@ -93,8 +93,6 @@ def test_transform_roles(): {"role": "model", "content": "Assistant message 1."}, ] - result = interface.transform_roles( - messages, transformations=[("system", "user"), ("assistant", "model")] - ) + result = interface.transform_roles(messages) assert result == expected_output