Commit 32ef94d

Merge branch 'main' into wangcx

2 parents: 10b550c + fd66819

18 files changed: +572 -75 lines changed

evals/completion_fns/gemini.py (+46 -5)
```diff
@@ -33,6 +33,42 @@ def get_completions(self) -> list[str]:
         return [self.response.strip()]
 
 
+def truncate_multimodal_prompt(prompt_list, max_images=16, max_size_bytes=4 * 1024 * 1024):
+    """
+    Truncates a list of texts and images to meet the constraints of maximum images and total size in bytes.
+
+    Parameters:
+    - prompt_list: List containing texts and images. Images are expected to be dictionaries with keys 'mime_type' and 'data'.
+    - max_images: Maximum number of images allowed.
+    - max_size_bytes: Maximum total size allowed in bytes.
+
+    Returns:
+    - A truncated list that fits the constraints.
+    """
+    truncated_list = []
+    total_size = 0
+    image_count = 0
+
+    for item in prompt_list:
+        if isinstance(item, str):  # It's text
+            item_size = len(item.encode('utf-8'))  # Size in bytes
+        elif isinstance(item, dict) and item.get('mime_type') and item.get('data'):  # It's an image
+            # The image data is a string representation of bytes; calculate its length accordingly.
+            item_size = len(item['data'])  # Approximation of size in bytes
+            image_count += 1
+        else:
+            continue  # Skip any item that doesn't fit expected structure
+
+        # Check if adding this item would exceed limits
+        if total_size + item_size > max_size_bytes or image_count > max_images:
+            break  # Stop adding items
+
+        total_size += item_size
+        truncated_list.append(item)
+
+    return truncated_list
+
+
 class GeminiCompletionFn(CompletionFn):
     def __init__(
         self,
@@ -71,7 +107,7 @@ def __call__(
         if "file_name" in kwargs:
             max_tokens = model_max_tokens.get(self.model, 1000000)
             attached_file_content = ["The file is as follows:"]
-
+
             if self.model == "gemini-pro-vision":
                 attached_file_content += extract_text_and_fill_in_images(kwargs["file_name"], None, False)
                 content_types = [type(c) for c in attached_file_content]
@@ -81,11 +117,16 @@
                     attached_file_content = ["The file is as follows:"] + ["".join(extract_text(kwargs["file_name"]))]
                 else:
                     attached_file_content += ["".join(extract_text(kwargs["file_name"]))]
+
+            contents = [openai_create_prompt] + attached_file_content
+
             if self.model == "gemini-pro":
-                while num_tokens_from_string(attached_file_content[1], "cl100k_base") > max_tokens:
-                    attached_file_content[1] = attached_file_content[1][:-1000]
+                while num_tokens_from_string(contents[2], "cl100k_base") > max_tokens:
+                    contents[2] = contents[2][:-1000]
+            elif self.model == "gemini-pro-vision":
+                contents = truncate_multimodal_prompt(contents, max_images=16, max_size_bytes=4 * 1024 * 1024)
         else:
-            attached_file_content = []
+            contents = [openai_create_prompt]
             self.model = "gemini-pro"
 
         genai.configure(api_key=np.random.choice(self.api_keys))
@@ -119,7 +160,7 @@ def __call__(
                 safety_settings=safety_settings)
             # response = request_with_timeout(model.generate_content, contents=[openai_create_prompt] + attached_file_content)
             response = model.generate_content(
-                contents=[openai_create_prompt] + attached_file_content,
+                contents=contents,
             )
             # answer = response.text
```
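Since `contents` is built as `[openai_create_prompt] + ["The file is as follows:"] + [extracted_text]`, the `contents[2]` indexed in the gemini-pro branch is the extracted document text (the old code indexed the same string as `attached_file_content[1]`). A minimal sketch of the new `truncate_multimodal_prompt` helper in isolation; the import path follows this diff, and the oversized placeholder strings are hypothetical stand-ins for real base64 image data:

```python
from evals.completion_fns.gemini import truncate_multimodal_prompt  # path as in this diff

prompt_list = [
    "Describe the figures below.",
    {"mime_type": "image/png", "data": "x" * (3 * 1024 * 1024)},  # ~3 MB placeholder image
    {"mime_type": "image/png", "data": "x" * (2 * 1024 * 1024)},  # would push the total past 4 MB
    "Summarize the document.",
]

kept = truncate_multimodal_prompt(prompt_list, max_images=16, max_size_bytes=4 * 1024 * 1024)
print(len(kept))  # 2: the leading text plus the first image fit under the 4 MB cap
```

Note that the helper breaks out of its loop at the first item that would exceed a limit, so everything after an oversized image is dropped, including any trailing text.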

evals/completion_fns/uni_finder.py (+2 -2)
```diff
@@ -66,7 +66,7 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
         url = f"{self.api_base}/api/external/upload_pdf"
         files = {'file': open(kwargs["file_name"], 'rb')}
         data = {
-            # 'pdf_parse_mode': self.pdf_parse_mode,
+            'pdf_parse_mode': "fast",
             'api_key': self.api_key,
             'model_engine': 'gpt',
         }
@@ -89,7 +89,7 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
         prompt = CompletionPrompt(prompt).to_formatted_prompt()
 
         payload = {
-            # "model_engine": self.model,
+            "model_engine": "gpt",
             "pdf_token": pdf_token,
             "query": prompt,
             'api_key': self.api_key,
```
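Both hunks swap commented-out, attribute-driven fields for hardcoded values: PDF parsing is pinned to "fast" and the model engine to "gpt". The HTTP call itself sits outside these hunks; a minimal sketch of the upload step, assuming the standard `requests.post` pattern (the base URL, API key, and file name below are placeholders, and the response field is an assumption inferred from the later payload):

```python
import requests

api_base = "https://uni-finder.example.com"  # placeholder for self.api_base
url = f"{api_base}/api/external/upload_pdf"
files = {"file": open("paper.pdf", "rb")}    # placeholder for kwargs["file_name"]
data = {
    "pdf_parse_mode": "fast",  # now hardcoded by this commit
    "api_key": "YOUR_API_KEY",
    "model_engine": "gpt",     # likewise hardcoded
}

resp = requests.post(url, files=files, data=data)
pdf_token = resp.json().get("pdf_token")  # assumed response field; the later payload sends it back
```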

evals/elsuite/rag_reaction_extract.py (+122, new file)
```diff
@@ -0,0 +1,122 @@
+import json
+import os
+import re
+import traceback
+import uuid
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+from pydantic import BaseModel
+
+import evals.metrics
+from evals.api import CompletionFn
+from evals.elsuite.rag_match import get_rag_dataset
+from evals.elsuite.utils import ReactionDictMatching, ReactionDictMatchingSimple
+from evals.record import RecorderBase, record_match
+
+code_pattern = r"```[\s\S]*?\n([\s\S]+?)\n```"
+json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```"
+csv_pattern = r"```csv[\s\S]*?\n([\s\S]+?)\n```"
+table_pattern = r"\n({index0}[\s\S]+)\n[`]*"
+outlink_pattern = r"\[Download[a-zA-Z0-9 ]+?\]\((https://[a-zA-Z0-9_. /]+?)\)"
+
+
+class FileSampleWithInput(BaseModel):
+    input: Optional[str]
+    file_name: Optional[str]
+    file_link: Optional[str]
+    answerfile_name: Optional[str]
+    answerfile_link: Optional[str]
+
+
+class ReactionExtract(evals.Eval):
+    def __init__(
+        self,
+        completion_fns: list[CompletionFn],
+        samples_jsonl: str,
+        *args,
+        instructions: Optional[str] = "",
+        **kwargs,
+    ):
+        super().__init__(completion_fns, *args, **kwargs)
+        assert len(completion_fns) < 3, "ReactionExtract only supports 3 completion fns"
+        self.samples_jsonl = samples_jsonl
+        self.instructions = instructions
+
+    def eval_sample(self, sample, rng):
+        assert isinstance(sample, dict)
+
+        input_formatted = sample["input"] if type(sample["input"]) == list else [{"role": "user", "content": sample["input"]}]
+        if self.instructions:
+            prompt = [{"role": "system", "content": self.instructions}] + input_formatted
+        else:
+            prompt = input_formatted
+
+        result = self.completion_fn(
+            prompt=prompt,
+            temperature=0.0,
+            file_name=sample["file_name"],
+            file_link=sample["file_link"]
+        )
+        sampled = result.get_completions()[0]
+        # correct_str = open(sample["answerfile_name"], 'r').read()
+        # correct_answer = json.loads(correct_str)
+        correct_answer = json.load(open(sample["answerfile_name"], 'r'))["inputs"]
+        correct_str = json.dumps(correct_answer, indent=4)
+
+        try:
+            if re.search(outlink_pattern, sampled) is not None:
+                code = re.search(outlink_pattern, sampled).group()
+                link = re.sub(outlink_pattern, r"\1", code)
+
+                fname = f"/tmp/LLMEvals_{uuid.uuid4()}.json"
+                os.system(f"wget {link} -O {fname}")
+                answer = json.load(open(fname, 'r'))
+            elif "json" in self.instructions:
+                code = re.search(json_pattern, sampled).group()
+                code_content = re.sub(json_pattern, r"\1", code)
+                code_content = code_content.replace('\\"', '"')
+
+                # Delete comments
+                code_content = re.sub(r'//.*', '', code_content)
+                answer = json.loads(code_content)
+            else:
+                answer = {}
+            picked_str = json.dumps(answer, indent=4)
+            open(sample["answerfile_name"].replace(".json", "_out.json"), 'w').write(picked_str)
+        except:
+            print(Path(sample["file_name"]).stem)
+            traceback.print_exc()
+            record_match(
+                prompt=prompt,
+                correct=False,
+                expected=correct_str,
+                picked=sampled,
+                file_name=sample["file_name"],
+                jobtype="match_all"
+            )
+            picked_str = "Failed to parse"
+            answer = {}
+            return {"accuracy_leaves": 0}
+
+        accuracy_leaves, df = ReactionDictMatchingSimple(correct_answer, answer, content="raw")
+        record_match(
+            prompt=prompt,
+            correct=(accuracy_leaves == 1.0),
+            expected=correct_str,
+            picked=picked_str,
+            file_name=sample["file_name"],
+            jobtype="match_all"
+        )
+        return {"accuracy_leaves": accuracy_leaves}
+
+    def run(self, recorder: RecorderBase):
+        samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
+        metrics_all_sample = self.eval_all_samples(recorder, samples)
+
+        metrics = {key: np.mean([sample_metrics[key] for sample_metrics in metrics_all_sample]) for key in metrics_all_sample[0].keys()}
+        # if "SMILES" in raw_samples[0]["compare_fields"]:
+        #     metrics["recall_SMILES"] = np.mean([sample_metrics["recall_SMILES"] for sample_metrics in metrics_all_sample
+        #                                         if "recall_SMILES" in sample_metrics])
+        return metrics
```
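The parsing logic in `eval_sample` prefers a download link in the completion, then falls back to a fenced ```json block. A self-contained sketch of the fallback path using the module's `json_pattern` and the same comment-stripping step (the sample completion is fabricated for illustration):

```python
import json
import re

json_pattern = r"```json[\s\S]*?\n([\s\S]+?)\n```"

sampled = (
    "Here is the extracted reaction:\n"
    "```json\n"
    '{"inputs": {"solvent": "THF", "temperature": "25 C"}}  // reaction conditions\n'
    "```"
)

code = re.search(json_pattern, sampled).group()   # the whole fenced block
code_content = re.sub(json_pattern, r"\1", code)  # keep only the capture group
code_content = re.sub(r"//.*", "", code_content)  # strip //-comments, as eval_sample does
answer = json.loads(code_content)
print(answer)  # {'inputs': {'solvent': 'THF', 'temperature': '25 C'}}
```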

evals/elsuite/rag_table_extract.py (+2)
```diff
@@ -238,6 +238,8 @@ def eval_sample(self, sample, rng):
 
         if table.shape[0] != 0:
             idxlist = table.columns
+            if type(sample.index) == str and table.columns.nlevels > 1:
+                sample.index = tuple([sample.index] + ["" for _ in range(table.columns.nlevels - 1)])
             if type(sample.index) in [str, tuple]:
                 if sample.index not in table.columns:
                     idxlist = [sample.index] + list(table.columns)[1:]
```
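The two added lines handle extracted tables whose headers parse as a pandas MultiIndex: a plain-string index column name never equals a multi-level column key, so it is padded with empty strings up to the column depth. A short illustration with a fabricated two-level header:

```python
import pandas as pd

# Fabricated table whose header has two levels, e.g. parsed from a two-row header.
table = pd.DataFrame(
    [[1, 90, 99]],
    columns=pd.MultiIndex.from_tuples([("Entry", ""), ("Yield", "%"), ("Purity", "%")]),
)

key = "Entry"                      # sample.index as a plain string
print(key in list(table.columns))  # False: the actual keys are tuples like ("Entry", "")

# Pad to the number of header levels, exactly as the added check does.
key = tuple([key] + ["" for _ in range(table.columns.nlevels - 1)])
print(key in list(table.columns))  # True: ("Entry", "") now matches
```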
