Skip to content

Commit 7f01360

Browse files
committed
Make RAGMatch a general pipeline, and add molecule similarity
1 parent dfc78d5 commit 7f01360

File tree

7 files changed

+214
-97
lines changed

7 files changed

+214
-97
lines changed

evals/elsuite/basic/match.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,17 @@ def eval_sample(self, sample: Any, *_):
4949
)
5050
sampled = result.get_completions()[0]
5151

52+
extras = {}
53+
if hasattr(result, "extras"):
54+
if "extracted_answer" in result.extras:
55+
sampled = result.extras["extracted_answer"].rstrip(".")
56+
extras = result.extras
57+
5258
return evals.record_and_check_match(
5359
prompt=prompt,
5460
sampled=sampled,
5561
expected=sample["ideal"],
62+
**extras
5663
)
5764

5865
def run(self, recorder):

evals/elsuite/rag_match.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import evals.metrics
1010
from evals.api import CompletionFn
1111
from evals.prompt.base import is_chat_prompt
12+
from evals.utils.misc import make_object
1213

1314

1415
def init_oss():
@@ -69,6 +70,9 @@ def __init__(
6970
max_tokens: int = 500,
7071
num_few_shot: int = 0,
7172
few_shot_jsonl: str = None,
73+
func_postprocess_answer: str = None,
74+
func_comparison: str = None,
75+
record_match_threshold: float = -1,
7276
**kwargs,
7377
):
7478
super().__init__(completion_fns, *args, **kwargs)
@@ -81,6 +85,10 @@ def __init__(
8185
self.few_shot_jsonl = few_shot_jsonl
8286
self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))
8387

88+
self.func_postprocess_answer = make_object(func_postprocess_answer) if func_postprocess_answer else None
89+
self.func_comparison = make_object(func_comparison) if func_comparison else None
90+
self.record_match_threshold = record_match_threshold
91+
8492
def eval_sample(self, sample: Any, *_):
8593
assert isinstance(sample, dict), "sample must be a dict"
8694
assert "input" in sample, "sample must have an 'input' key"
@@ -102,27 +110,61 @@ def eval_sample(self, sample: Any, *_):
102110
temperature=0.0,
103111
**{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
104112
)
105-
sampled = result.get_completions()[0]
113+
sampled = result.get_completions()[0].strip()
106114

107-
extras = {}
115+
extras = {"file_name": sample["file_name"], "file_link": sample["file_link"]} if "file_name" in sample else {}
108116
if hasattr(result, "extras"):
109117
if "extracted_answer" in result.extras:
110118
sampled = result.extras["extracted_answer"].rstrip(".")
111119
extras = result.extras
112-
113-
return evals.record_and_check_match(
114-
prompt=prompt,
115-
sampled=sampled,
116-
expected=sample["ideal"],
117-
file_name=sample["file_name"],
118-
**extras
119-
)
120+
else:
121+
extras["answer"] = sampled
122+
123+
if self.func_postprocess_answer:
124+
extras["answer"] = sampled
125+
sampled = extras["extracted_answer"] = self.func_postprocess_answer(sampled)
126+
127+
if self.func_comparison:
128+
metrics = self.func_comparison(sampled, sample["ideal"][0])
129+
if type(metrics) == bool:
130+
evals.record.record_match(correct=metrics,
131+
expected=sample["ideal"],
132+
picked=sampled, sampled=extras["answer"],
133+
prompt=prompt,
134+
**extras)
135+
else:
136+
evals.record.record_metrics(**metrics)
137+
if self.record_match_threshold > 0:
138+
evals.record.record_match(correct=metrics["score"] >= self.record_match_threshold,
139+
**metrics,
140+
expected=sample["ideal"],
141+
picked=sampled, sampled=extras["answer"],
142+
prompt=prompt,
143+
**extras)
144+
else:
145+
return evals.record_and_check_match(
146+
prompt=prompt,
147+
sampled=sampled,
148+
expected=sample["ideal"],
149+
**extras
150+
)
120151

121152
def run(self, recorder):
122153
samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
123154
self.eval_all_samples(recorder, samples)
155+
124156
events = recorder.get_events("match")
125-
return {
126-
"accuracy": evals.metrics.get_accuracy(events),
127-
"boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
128-
}
157+
if len(events) > 0:
158+
record_metrics = {
159+
"accuracy": evals.metrics.get_accuracy(events),
160+
"bootstrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
161+
}
162+
else:
163+
record_metrics = {}
164+
165+
all_sample_metrics = recorder.get_metrics()
166+
scores = [m["score"] for m in all_sample_metrics if m.get("score") is not None]
167+
if scores:
168+
record_metrics["score"] = sum(scores) / len(scores)
169+
170+
return record_metrics

evals/elsuite/utils.py

Lines changed: 139 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@
123123
}
124124

125125

126+
# Highlight: Part 1. Post-Processing Functions for LLM Outputs
127+
126128
def get_answer(text, answer_prompt, ignore_case=False):
127129
if ignore_case:
128130
idx = text.lower().rfind(answer_prompt.lower())
@@ -142,22 +144,6 @@ def get_consensus(answers):
142144
return max(counts, key=counts.get)
143145

144146

145-
def compare_molecule(smi1, smi2) -> bool:
146-
from rdkit import Chem
147-
from rdkit.Chem import AllChem
148-
149-
mol1 = Chem.MolFromSmiles(smi1)
150-
mol2 = Chem.MolFromSmiles(smi2)
151-
if mol1 is None or mol2 is None:
152-
return False
153-
else:
154-
return Chem.MolToSmiles(Chem.RemoveHs(mol1)) == Chem.MolToSmiles(Chem.RemoveHs(mol2))
155-
# return False
156-
# fp1 = AllChem.GetMorganFingerprint(mol1, 2)
157-
# fp2 = AllChem.GetMorganFingerprint(mol2, 2)
158-
# return DataStructs.TanimotoSimilarity(fp1, fp2)
159-
160-
161147
def normalize(s: str) -> str:
162148
"""Lower text and remove punctuation, articles and extra whitespace."""
163149
s = s.lower()
@@ -168,6 +154,87 @@ def normalize(s: str) -> str:
168154
return s
169155

170156

157+
def fuzzy_normalize_name(s):
158+
if s.startswith("Unnamed"):
159+
return ""
160+
else:
161+
""" 标准化字符串 """
162+
# # 定义需要移除的单位和符号
163+
# units = ["µM", "µg/mL", "nM", "%", "wt.%", "at.%", "at%", "wt%"]
164+
# for unit in units:
165+
# s = s.replace(unit, "")
166+
167+
# 定义特定关键字
168+
keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"]
169+
170+
# 移除非字母数字的字符,除了空格
171+
s = re.sub(r'[^\w\s%.\-\(\)]', '', s)
172+
if s in synonyms:
173+
s = synonyms[s]
174+
175+
# 分割字符串为单词列表
176+
words = s.split()
177+
178+
# 将关键字移到末尾
179+
reordered_words = [word for word in words if word not in keywords]
180+
keywords_in_string = [word for word in words if word in keywords]
181+
reordered_words.extend(keywords_in_string)
182+
# 重新组合为字符串
183+
return ' '.join(reordered_words)
184+
185+
186+
def fuzzy_normalize_value(vi):
187+
try:
188+
vi = str(vi).lower()
189+
190+
if "bal" in vi or "remainder" in vi or "bas" in vi:
191+
vi = "bal"
192+
return "bal"
193+
194+
if ("nan" in vi and not "–" in vi) or "/" == vi or "n/a" in vi or "na" in vi or vi == "":
195+
vi = "0"
196+
vi = vi.replace("nan", "–").replace("~", "-")
197+
198+
pattern = r"\d+(?:\.\d+)?"
199+
matches = re.findall(pattern, vi)
200+
if len(matches) == 2:
201+
vi = f"{matches[0]}-{matches[1]}"
202+
elif len(matches) == 1:
203+
vi = matches[0]
204+
205+
if "<" in vi:
206+
vi = vi.replace("<", "")
207+
if ">" in vi:
208+
vi = vi.replace(">", "")
209+
210+
try:
211+
vi = float(vi)
212+
vi = round(vi, 3)
213+
except:
214+
# print(vi)
215+
pass
216+
except:
217+
pass
218+
219+
return vi
220+
221+
222+
def extract_choice_and_value(sampled):
223+
pattern = re.compile(r'\w\)\s\d+(?:\.\d+)?(?:\s?:\s?\d+(?:\.\d+)?)?\s?[°]?[CK]?')
224+
matches = pattern.findall(sampled)
225+
if matches:
226+
sampled0 = pattern.findall(sampled)[0]
227+
else:
228+
return "No answer."
229+
if sampled0 is None or sampled0 == []:
230+
pass
231+
else:
232+
sampled = sampled0.replace("°", " ")
233+
sampled = sampled.replace(" ", " ")
234+
return sampled
235+
236+
# Part 2. Comparison Functions for Post-Processed LLM Outputs
237+
171238
def fuzzy_match(s1: str, s2: str) -> bool:
172239
s1 = normalize(s1)
173240
s2 = normalize(s2)
@@ -264,69 +331,32 @@ def is_float(str):
264331
pass
265332

266333

267-
def fuzzy_normalize_name(s):
268-
if s.startswith("Unnamed"):
269-
return ""
270-
else:
271-
""" 标准化字符串 """
272-
# # 定义需要移除的单位和符号
273-
# units = ["µM", "µg/mL", "nM", "%", "wt.%", "at.%", "at%", "wt%"]
274-
# for unit in units:
275-
# s = s.replace(unit, "")
276-
277-
# 定义特定关键字
278-
keywords = ["pIC50", "IC50", "EC50", "TC50", "GI50", "Ki", "Kd", "Kb", "pKb"]
279-
280-
# 移除非字母数字的字符,除了空格
281-
s = re.sub(r'[^\w\s%.\-\(\)]', '', s)
282-
if s in synonyms:
283-
s = synonyms[s]
284-
285-
# 分割字符串为单词列表
286-
words = s.split()
287-
288-
# 将关键字移到末尾
289-
reordered_words = [word for word in words if word not in keywords]
290-
keywords_in_string = [word for word in words if word in keywords]
291-
reordered_words.extend(keywords_in_string)
292-
# 重新组合为字符串
293-
return ' '.join(reordered_words)
294-
295-
296-
def fuzzy_normalize_value(vi):
297-
try:
298-
vi = str(vi).lower()
299-
300-
if "bal" in vi or "remainder" in vi or "bas" in vi:
301-
vi = "bal"
302-
return "bal"
334+
def compare_molecule_similarity(smi1, smi2) -> dict:
335+
from rdkit import Chem
336+
from rdkit.Chem import AllChem
337+
from rdkit import DataStructs
303338

304-
if ("nan" in vi and not "–" in vi) or "/" == vi or "n/a" in vi or "na" in vi or vi == "":
305-
vi = "0"
306-
vi = vi.replace("nan", "–").replace("~", "-")
339+
mol1 = Chem.MolFromSmiles(re.sub(r'<.*>', '', str(smi1).strip("`")))
340+
mol2 = Chem.MolFromSmiles(re.sub(r'<.*>', '', str(smi2).strip("`")))
307341

308-
pattern = r"\d+(?:\.\d+)?"
309-
matches = re.findall(pattern, vi)
310-
if len(matches) == 2:
311-
vi = f"{matches[0]}-{matches[1]}"
312-
elif len(matches) == 1:
313-
vi = matches[0]
342+
if mol1 is None or mol2 is None:
343+
sim = 0.0
344+
else:
345+
fp1 = AllChem.GetMorganFingerprint(mol1, 2)
346+
fp2 = AllChem.GetMorganFingerprint(mol2, 2)
347+
sim = DataStructs.TanimotoSimilarity(fp1, fp2)
348+
return {"score": sim}
314349

315-
if "<" in vi:
316-
vi = vi.replace("<", "")
317-
if ">" in vi:
318-
vi = vi.replace(">", "")
319350

320-
try:
321-
vi = float(vi)
322-
vi = round(vi, 3)
323-
except:
324-
# print(vi)
325-
pass
326-
except:
327-
pass
351+
def compare_molecule_strict(smi1, smi2) -> bool:
352+
from rdkit import Chem
328353

329-
return vi
354+
mol1 = Chem.MolFromSmiles(smi1)
355+
mol2 = Chem.MolFromSmiles(smi2)
356+
if mol1 is None or mol2 is None:
357+
return False
358+
else:
359+
return Chem.MolToSmiles(Chem.RemoveHs(mol1)) == Chem.MolToSmiles(Chem.RemoveHs(mol2))
330360

331361

332362
def tableMatching(df_ref, df_prompt, index='Compound', compare_fields=[], record=True, file_name=None):
@@ -350,7 +380,7 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
350380
Match the indices of two dataframes.
351381
"""
352382
renames = {}
353-
name2query = lambda name: name if type(name) != tuple else name[0] if len(name) == 1 or name[1] == "" else name[1]
383+
name2query = lambda name: name if type(name) != tuple else name[0] if len(name) == 1 or name[-1] == "" else name[-1]
354384
similarities = np.array(np.ones([len(ind0) + 15, len(ind1) + 15]), dtype=np.float64)
355385
querys0 = [name2query(name) for name in ind0]
356386
querys1 = [name2query(name) for name in ind1]
@@ -434,7 +464,7 @@ def match_indices(ind0, ind1, threshold=0.9) -> dict:
434464
except:
435465
p = 'not found'
436466

437-
_is_matching = fuzzy_compare_name(gt, p, compare_value=True) if col != "SMILES" else compare_molecule(gt, p)
467+
_is_matching = fuzzy_compare_name(gt, p, compare_value=True) if col != "SMILES" else compare_molecule_strict(gt, p)
438468
if col == "SMILES":
439469
smiles_match_score += float(_is_matching)
440470
if record:
@@ -558,6 +588,38 @@ def count_leaves(d, count=0):
558588
return 0
559589
ratio = total_diff_leaves / total_leaves_dict1
560590

591+
if total_diff_leaves == total_leaves_dict1 and len(list(dict_ref.keys())) == len(list(dict_prompt.keys())):
592+
values1 = list(dict_ref.values())
593+
values2 = list(dict_prompt.values())
594+
595+
# Initialize containers for differences
596+
differences = []
597+
598+
# The maximum length to iterate over
599+
max_length = max(len(values1), len(values2))
600+
601+
total_diff_leaves = 0
602+
603+
for i in range(max_length):
604+
try:
605+
value1 = values1[i]
606+
value2 = values2[i]
607+
except IndexError:
608+
# Handle cases where the lists have different lengths
609+
differences.append('Different number of elements.')
610+
break
611+
612+
# If both values are dictionaries, use DeepDiff to compare them deeply
613+
if isinstance(value1, dict) and isinstance(value2, dict):
614+
diff = DeepDiff(value1, value2, ignore_order=True, report_repetition=True)
615+
if diff:
616+
total_diff_leaves += sum(len(diff.get(key, {})) for key in diff_keys)
617+
differences.append(diff)
618+
elif value1 != value2:
619+
total_diff_leaves += 1
620+
# For non-dictionary values, just compare them directly
621+
differences.append({'different_values': (value1, value2)})
622+
561623
return 1.0 - ratio, diff
562624

563625

@@ -863,6 +925,7 @@ def macro_f1_score_3(model, prediction: List[List[Any]], answers: List[List[Any]
863925
except:
864926
return 0.0
865927

928+
866929
def scrub_formatting_from_prompt(prompt):
867930
scrubbed_prompt = copy.copy(prompt)
868931

0 commit comments

Comments
 (0)