
Commit 5a92ac3

Add Gemini Solver (#1503)

ojaffe and JunShern authored
Adds a solver for Gemini 1.5 Pro. Stacked on #1501 and #1482. Using the solver requires the `GEMINI_API_KEY` environment variable.

Test with:

```
oaieval generation/direct/gemini-pro bugged_tools
```

Co-authored-by: Chan Jun Shern <[email protected]>
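For reference, a minimal sketch of driving the new solver directly from Python (not part of the commit; it assumes `GEMINI_API_KEY` is exported and mirrors the usage in the test file added below):

```python
# Requires GEMINI_API_KEY in the environment; the solver module reads it at import time.
from evals.record import DummyRecorder
from evals.solvers.providers.google.gemini_solver import GeminiSolver
from evals.task_state import Message, TaskState

solver = GeminiSolver(model_name="gemini-pro")
task_state = TaskState(
    task_description="Answer in one word.",
    messages=[Message(role="user", content="What is the capital of France?")],
)

# _solve() records the sample via record_sampling(), which needs an active
# recorder, so wrap the call in the dummy recorder used by the unit tests.
recorder = DummyRecorder(None)
with recorder.as_default_recorder("example"):
    result = solver(task_state=task_state)

print(result.output)
```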
1 parent 150dcb9 commit 5a92ac3

5 files changed: +308 −1 lines changed

evals/registry/solvers/gemini.yaml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# ------------------
# gemini-pro
# ------------------

# generation tasks

generation/direct/gemini-pro:
  class: evals.solvers.providers.google.gemini_solver:GeminiSolver
  args:
    model_name: gemini-pro

generation/cot/gemini-pro:
  class: evals.solvers.nested.cot_solver:CoTSolver
  args:
    cot_solver:
      class: evals.solvers.providers.google.gemini_solver:GeminiSolver
      args:
        model_name: gemini-pro
    extract_solver:
      class: evals.solvers.providers.google.gemini_solver:GeminiSolver
      args:
        model_name: gemini-pro
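The `args` in each registry entry correspond to the solver's constructor arguments. Besides `model_name`, `GeminiSolver` also accepts an optional `generation_config` dict that is unpacked into `genai.GenerationConfig` (see the solver module below), and the `generation/cot/gemini-pro` entry wraps the same model in `CoTSolver` so that Gemini produces both the chain of thought (`cot_solver`) and the extracted answer (`extract_solver`). A small sketch of the direct solver with hypothetical sampling settings:

```python
from evals.solvers.providers.google.gemini_solver import GeminiSolver

# generation_config is unpacked into genai.GenerationConfig; the values below
# are illustrative, not part of this commit.
solver = GeminiSolver(
    model_name="gemini-pro",
    generation_config={"temperature": 0.0, "max_output_tokens": 512},
)
```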
evals/solvers/providers/google/gemini_solver.py

Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
import copy
import os
from dataclasses import asdict, dataclass
from typing import Any, Dict, Union

import google.api_core.exceptions
import google.generativeai as genai
from google.generativeai.client import get_default_generative_client

from evals.record import record_sampling
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState
from evals.utils.api_utils import create_retrying

# Load API key from environment variable
API_KEY = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=API_KEY)

SAFETY_SETTINGS = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
]
GEMINI_RETRY_EXCEPTIONS = (
    google.api_core.exceptions.RetryError,
    google.api_core.exceptions.TooManyRequests,
    google.api_core.exceptions.ResourceExhausted,
)


# TODO: Could we just use google's own types?
# e.g. google.generativeai.types.content_types.ContentType
@dataclass
class GoogleMessage:
    role: str
    parts: list[str]

    def to_dict(self):
        return asdict(self)

    @staticmethod
    def from_evals_message(msg: Message):
        valid_roles = {"user", "model"}
        to_google_role = {
            "system": "user",  # Google doesn't have a "system" role
            "user": "user",
            "assistant": "model",
        }
        gmsg = GoogleMessage(
            role=to_google_role.get(msg.role, msg.role),
            parts=[msg.content],
        )
        assert gmsg.role in valid_roles, f"Invalid role: {gmsg.role}"
        return gmsg


class GeminiSolver(Solver):
    """
    A solver class that uses Google's Gemini API to generate responses.
    """

    def __init__(
        self,
        model_name: str,
        generation_config: Dict[str, Any] = {},
        postprocessors: list[str] = [],
        registry: Any = None,
    ):
        super().__init__(postprocessors=postprocessors)

        self.model_name = model_name
        self.gen_config = genai.GenerationConfig(**generation_config)

        # We manually define the client. This is normally defined automatically when calling
        # the API, but it isn't thread-safe, so we anticipate its creation here
        self.glm_client = get_default_generative_client()

    @property
    def model(self) -> str:
        return self.model_name

    def _solve(
        self,
        task_state: TaskState,
        **kwargs,
    ) -> SolverResult:
        msgs = [
            Message(role="user", content=task_state.task_description),
        ] + task_state.messages
        gmsgs = self._convert_msgs_to_google_format(msgs)
        gmsgs = [msg.to_dict() for msg in gmsgs]
        try:
            glm_model = genai.GenerativeModel(model_name=self.model_name)
            glm_model._client = self.glm_client

            gen_content_resp = create_retrying(
                glm_model.generate_content,
                retry_exceptions=GEMINI_RETRY_EXCEPTIONS,
                **{
                    "contents": gmsgs,
                    "generation_config": self.gen_config,
                    "safety_settings": SAFETY_SETTINGS,
                },
            )
            if gen_content_resp.prompt_feedback.block_reason:
                # Blocked by safety filters
                solver_result = SolverResult(
                    str(gen_content_resp.prompt_feedback),
                    error=gen_content_resp.prompt_feedback,
                )
            else:
                # Get text response
                solver_result = SolverResult(
                    gen_content_resp.text,
                    error=gen_content_resp.prompt_feedback,
                )
        except (google.api_core.exceptions.GoogleAPIError,) as e:
            solver_result = SolverResult(
                e.message,
                error=e,
            )
        except ValueError as e:
            # TODO: Why does this error ever occur and how can we handle it better?
            # (See google/generativeai/types/generation_types.py for the triggers)
            known_errors = [
                "The `response.text` quick accessor",
                "The `response.parts` quick accessor",
            ]
            if any(err in str(e) for err in known_errors):
                solver_result = SolverResult(
                    str(e),
                    error=e,
                )
            else:
                raise e

        record_sampling(
            prompt=msgs,
            sampled=[solver_result.output],
            model=self.model,
        )
        return solver_result

    @staticmethod
    def _convert_msgs_to_google_format(msgs: list[Message]) -> list[GoogleMessage]:
        """
        Gemini API requires that the message list has
        - Roles as 'user' or 'model'
        - Alternating 'user' and 'model' messages
        - Ends with a 'user' message
        """
        # Enforce valid roles
        gmsgs = []
        for msg in msgs:
            gmsg = GoogleMessage.from_evals_message(msg)
            gmsgs.append(gmsg)
            assert gmsg.role in {"user", "model"}, f"Invalid role: {gmsg.role}"

        # Enforce alternating messages
        # e.g. [user1, user2, model1, user3] -> [user12, model1, user3]
        std_msgs = []
        for msg in gmsgs:
            if len(std_msgs) > 0 and msg.role == std_msgs[-1].role:
                # Merge consecutive messages from the same role
                std_msgs[-1].parts.extend(msg.parts)
                # The API seems to expect a single-element list of strings (???) so we join the
                # parts into a list containing a single string
                std_msgs[-1].parts = ["\n".join(std_msgs[-1].parts)]
            else:
                # Proceed as normal
                std_msgs.append(msg)

        # Enforce last message is from the user
        assert std_msgs[-1].role == "user", "Last message must be from the user"
        return std_msgs

    @property
    def name(self) -> str:
        return self.model

    @property
    def model_version(self) -> Union[str, dict]:
        return self.model

    def __deepcopy__(self, memo):
        """
        Deepcopy everything except for self.glm_client, which is instead shared across all copies
        """
        cls = self.__class__
        result = cls.__new__(cls)

        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != "glm_client":
                setattr(result, k, copy.deepcopy(v, memo))

        result.glm_client = self.glm_client
        return result
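As a small illustration of the role mapping performed by `GoogleMessage.from_evals_message` (not part of the commit): `system` has no Gemini equivalent and is folded into `user`, while `assistant` becomes `model`.

```python
from evals.solvers.providers.google.gemini_solver import GoogleMessage
from evals.task_state import Message

# "assistant" is translated to Gemini's "model" role; "system" would become "user".
gmsg = GoogleMessage.from_evals_message(Message(role="assistant", content="4"))
print(gmsg.to_dict())  # -> {'role': 'model', 'parts': ['4']}
```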
evals/solvers/providers/google/gemini_solver_test.py

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
import os

import pytest

from evals.record import DummyRecorder
from evals.solvers.providers.google.gemini_solver import GeminiSolver, GoogleMessage
from evals.task_state import Message, TaskState

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
MODEL_NAME = "gemini-pro"


@pytest.fixture
def dummy_recorder():
    recorder = DummyRecorder(None)  # type: ignore
    with recorder.as_default_recorder("x"):
        yield recorder


@pytest.fixture
def gemini_solver():
    os.environ["EVALS_SEQUENTIAL"] = "1"  # TODO: Remove after fixing threading issue
    solver = GeminiSolver(
        model_name=MODEL_NAME,
    )
    return solver


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit.")
def test_solver(dummy_recorder, gemini_solver):
    """
    Test that the solver generates a response coherent with the message history
    while following the instructions from the task description.
    """
    solver = gemini_solver

    answer = "John Doe"
    task_state = TaskState(
        task_description=f"When you are asked for your name, respond with '{answer}' (without quotes).",
        messages=[
            Message(role="user", content="What is 2 + 2?"),
            Message(role="assistant", content="4"),
            Message(role="user", content="What is your name?"),
        ],
    )

    solver_res = solver(task_state=task_state)
    assert solver_res.output == answer, f"Expected '{answer}', but got {solver_res.output}"


def test_message_format():
    """
    Test that messages in our evals format is correctly converted to the format
    expected by Gemini.
    """

    messages = [
        Message(role="system", content="You are a great mathematician."),
        Message(role="user", content="What is 2 + 2?"),
        Message(role="assistant", content="5"),
        Message(role="user", content="That's incorrect. What is 2 + 2?"),
    ]

    gmessages = GeminiSolver._convert_msgs_to_google_format(messages)
    expected = [
        GoogleMessage(role="user", parts=["You are a great mathematician.\nWhat is 2 + 2?"]),
        GoogleMessage(role="model", parts=["5"]),
        GoogleMessage(role="user", parts=["That's incorrect. What is 2 + 2?"]),
    ]

    assert gmessages == expected, f"Expected {expected}, but got {gmessages}"
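One constraint worth noting beyond the tests above (the snippet below is illustrative, not part of the commit): `_convert_msgs_to_google_format` asserts that the converted history ends with a `user` message, so a conversation whose final turn is from the assistant raises an `AssertionError`.

```python
from evals.solvers.providers.google.gemini_solver import GeminiSolver
from evals.task_state import Message

# The converted history must end with a "user" message, per the assert in
# _convert_msgs_to_google_format; an assistant-final history is rejected.
history = [
    Message(role="user", content="What is 2 + 2?"),
    Message(role="assistant", content="4"),
]
try:
    GeminiSolver._convert_msgs_to_google_format(history)
except AssertionError as err:
    print(err)  # "Last message must be from the user"
```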
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
google-generativeai

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ dependencies = [
     "gymnasium",
     "networkx",
     "chess",
-    "anthropic"
+    "anthropic",
+    "google-generativeai",
 ]

 [project.urls]
