Skip to content

Commit ae76148

Browse files
committed
2 parents b145b96 + 733291c commit ae76148

14 files changed

+191
-8
lines changed

Diff for: evals/elsuite/choice_match.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import oss2
6+
from oss2.credentials import EnvironmentVariableCredentialsProvider
7+
8+
import evals
9+
import evals.metrics
10+
from evals.api import CompletionFn
11+
from evals.prompt.base import is_chat_prompt
12+
13+
14+
def init_oss():
    """
    Create the OSS bucket handle used by this module.

    Credentials come from the environment: set OSS_ACCESS_KEY_ID and
    OSS_ACCESS_KEY_SECRET before calling.

    Returns:
        An ``oss2.Bucket`` bound to the hard-coded endpoint and bucket name.
    """
    credentials = EnvironmentVariableCredentialsProvider()
    return oss2.Bucket(
        oss2.ProviderAuth(credentials),
        'https://oss-cn-beijing.aliyuncs.com',  # endpoint
        'dp-filetrans-bj',  # bucket name
    )
29+
30+
31+
def get_rag_dataset(samples_jsonl: str) -> list[dict]:
    """
    Load samples and keep their attachments in sync between disk and OSS.

    For every sample and for each key prefix ("" and "answer"):

    * if ``<prefix>file_name`` is present, ``<prefix>file_link`` is set to a
      URL under the bucket host and the local file is uploaded unless the
      object already exists in the bucket;
    * afterwards, if the local file is missing on disk and the object exists
      in the bucket, it is downloaded (parent directories are created).

    Args:
        samples_jsonl: Path of the JSONL file passed to ``evals.get_jsonl``.

    Returns:
        The loaded samples, mutated in place with the generated links.
    """
    bucket = init_oss()
    samples = evals.get_jsonl(samples_jsonl)

    for entry in samples:
        for ftype in ("", "answer"):
            name_key = f"{ftype}file_name"
            link_key = f"{ftype}file_link"
            has_name = name_key in entry
            if not has_name and link_key not in entry:
                continue

            if has_name:
                oss_file = "changjunhan/" + os.path.basename(entry[name_key])
                entry[link_key] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file

                if bucket.object_exists(oss_file):
                    print(f"文件 {oss_file} 已存在于 OSS 中。")
                else:
                    # Upload the local file.
                    bucket.put_object_from_file(oss_file, entry[name_key])
                    print(f"文件 {oss_file} 已上传到 OSS。")

            # The link key always exists here: either the sample carried it,
            # or the branch above has just written it.
            if has_name:
                local_file = entry[name_key]
            else:
                local_file = os.path.basename(entry[link_key])
            oss_file = "changjunhan/" + os.path.basename(entry[link_key])

            if not os.path.exists(local_file) and bucket.object_exists(oss_file):
                # Download the file from OSS.
                Path(local_file).parent.mkdir(parents=True, exist_ok=True)
                bucket.get_object_to_file(oss_file, local_file)
                print(f"文件 {oss_file} 已下载到本地。")
    return samples
61+
62+
63+
class RAGMatch(evals.Eval):
    """
    Multiple-choice eval scored on the option labels found in the answer.

    The last non-blank line of the completion is treated as the answer. It is
    accepted when, for every label in ``a)``, ``b)``, ``c)``, ``d)``, the label
    occurs in the answer exactly when it occurs in ``sample["ideal"]``.
    """

    # Option labels compared between the ideal answer and the completion.
    _CHOICE_LABELS = ("a)", "b)", "c)", "d)")

    def __init__(
        self,
        completion_fns: list[CompletionFn],
        samples_jsonl: str,
        *args,
        max_tokens: int = 500,
        num_few_shot: int = 0,
        few_shot_jsonl: str = None,
        **kwargs,
    ):
        """
        Args:
            completion_fns: Exactly one completion function.
            samples_jsonl: Registry-relative path of the samples file.
            max_tokens: Stored on the instance as ``self.max_tokens``.
            num_few_shot: Number of few-shot examples spliced into the prompt.
            few_shot_jsonl: Registry-relative path of the few-shot samples;
                required when ``num_few_shot`` is greater than zero.
        """
        super().__init__(completion_fns, *args, **kwargs)
        assert len(completion_fns) == 1, "Match only supports one completion fn"
        self.max_tokens = max_tokens
        self.samples_jsonl = samples_jsonl
        self.num_few_shot = num_few_shot
        if self.num_few_shot > 0:
            assert few_shot_jsonl is not None, "few shot requires few shot sample dataset"
            self.few_shot_jsonl = few_shot_jsonl
            self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))

    @staticmethod
    def _last_answer_line(text: str) -> str:
        """Return the last non-blank line of *text*, or its first line if all are blank."""
        lines = text.split("\n")
        # Whitespace-only lines (e.g. a trailing "\r" or spaces) count as
        # blank so they cannot shadow the real answer line above them.
        return next((line for line in reversed(lines) if line.strip()), lines[0])

    @classmethod
    def _same_choices(cls, answer: str, ideal) -> bool:
        """Tell whether *answer* and *ideal* contain exactly the same option labels."""
        # For a string ideal this is a substring test. For a list ideal it is
        # element membership, as in the original implementation.
        # NOTE(review): confirm that list ideals are lists of bare labels.
        return all((label in ideal) == (label in answer) for label in cls._CHOICE_LABELS)

    def eval_sample(self, sample: Any, *_):
        """
        Query the completion function for one sample and record the match.

        Args:
            sample: Dict with ``input`` (the prompt), ``ideal`` (a string or a
                list of strings) and ``file_name``; every other key is passed
                to the completion function as a keyword argument.

        Returns:
            Whatever ``evals.record_and_check_match`` returns.
        """
        assert isinstance(sample, dict), "sample must be a dict"
        assert "input" in sample, "sample must have an 'input' key"
        assert "ideal" in sample, "sample must have an 'ideal' key"
        assert isinstance(
            sample["ideal"], (str, list)
        ), "sample['ideal'] must be a string or list of strings"

        ideal = sample["ideal"]
        prompt = sample["input"]
        if self.num_few_shot > 0:
            assert is_chat_prompt(sample["input"]), "few shot requires chat prompt"
            # Insert the few-shot messages just before the final message.
            prompt = sample["input"][:-1]
            for s in self.few_shot[: self.num_few_shot]:
                prompt += s["sample"]
            prompt += sample["input"][-1:]

        result = self.completion_fn(
            prompt=prompt,
            temperature=0.0,
            **{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
        )
        sampled = result.get_completions()[0]

        extras = {}
        if hasattr(result, "extras"):
            if "extracted_answer" in result.extras:
                sampled = result.extras["extracted_answer"].rstrip(".")
            extras = result.extras
        print(sampled)

        answer_line = self._last_answer_line(sampled)
        if answer_line != "" and self._same_choices(answer_line, ideal):
            # Report a string that matches the expectation. A list-valued
            # ideal must not be passed as `sampled`, which is expected to be
            # a string; use its first entry instead.
            if isinstance(ideal, str):
                sampled = ideal
            else:
                sampled = ideal[0] if ideal else ""
        else:
            # An empty string marks the answer as wrong.
            sampled = ""
        print("compare", sampled, sample["ideal"])

        # NOTE(review): sample["file_name"] raises KeyError for samples that
        # only carry "file_link" - confirm every sample has "file_name".
        return evals.record_and_check_match(
            prompt=prompt,
            sampled=sampled,
            expected=sample["ideal"],
            file_name=sample["file_name"],
            **extras
        )

    def run(self, recorder):
        """
        Evaluate all samples and compute summary metrics from the match events.

        Returns:
            Dict with ``accuracy`` and ``boostrap_std`` (key spelling kept
            unchanged since it is emitted at runtime).
        """
        samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
        self.eval_all_samples(recorder, samples)
        events = recorder.get_events("match")
        return {
            "accuracy": evals.metrics.get_accuracy(events),
            "boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
        }

Diff for: evals/elsuite/rag_table_extract_comp.py

-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,6 @@ def eval_sample(self, sample, rng):
310310
result = self.completion_fn(
311311
prompt=prompt,
312312
temperature=0.0,
313-
max_tokens=5,
314313
file_name=sample.file_name,
315314
file_link=sample.file_link
316315
)

Diff for: evals/elsuite/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def tableMatching(df_ref, df_prompt, index='Compound', compare_fields=[], record
367367
return {"recall_field": 0.0, "recall_index": 0.0, "recall_value": 0.0, "recall_value_strict": 0.0,
368368
"accuracy_value": 0.0, "accuracy_value_strict": 0.0, "recall_SMILES": 0.0}
369369
metrics = {}
370-
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate"]
370+
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate","AlloyName"]
371371

372372
if index not in [None, ""]:
373373
df_ref[index] = df_ref[index].astype(str)

Diff for: evals/registry/data/01_alloychart/samples.jsonl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:805f5c92e67bbb8a143fd30368bd0a480486d8733a1e1c1bed8985b2ad007bde
3+
size 7783
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:9b2190d4d26fc45c5ba0e6615f3c10964fd7a547be434c58e548f91eb2e412c5
3-
size 33632
2+
oid sha256:5a4facecc1fe28bd1d84fae1273eecdaa03c2eac65e4a66a3f348f8ddff671b6
3+
size 33206

Diff for: evals/registry/data/01_alloynum/alloy_number.jsonl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:70ae2613eb223f2a9b7e8cbf10f81b5e626a9ae6827d94c7e3e25df1b8e4e4bf
3-
size 36088
2+
oid sha256:2ca8af0565b74d457d188e00a5f25ff574b5d40175bc2c6321523d47753907f9
3+
size 35391

Diff for: evals/registry/data/01_alloysort/sort.jsonl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:45b0bf49f634caf51353105a1d62a5cf8285654b0ba8a3e744442b879400bfd0
3-
size 15245
2+
oid sha256:d70ec2f034dc30585433fd1eff176bf5d72035d6a421062d0dc834a8d434a458
3+
size 15876

Diff for: evals/registry/data/05_biochart/samples.jsonl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:f17bc79d6a2f3de2d03c1682462fb1453a654cbc017f42d1edbe96689aca9895
3+
size 7500

Diff for: evals/registry/data/05_biochart/samples_single.jsonl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:fe96f94f5279ca67003295191c7f2c2002976a49bbe4856e58d0799bb7f58f2c
3+
size 7664
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:1093a3a82f10c6ac2ce9f9896cc9159bbea4380c3d4bf9e4977224ab02f4df52
3+
size 7636

Diff for: evals/registry/data/drug_ChartQA/samples copy.jsonl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:3f5cf6d65edf256002bdd6d4bf8fdae23698a6150b23bf778e246fa9c51b8074
3+
size 8787

Diff for: evals/registry/data/drug_ChartQA/samples.jsonl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:3f5cf6d65edf256002bdd6d4bf8fdae23698a6150b23bf778e246fa9c51b8074
3+
size 8787

Diff for: evals/registry/evals/01_scipaper_alloychart.yaml

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
alloychart:
3+
id: alloychart.dev.v0
4+
metrics: [accuracy]
5+
6+
alloychart.dev.v0:
7+
class: evals.elsuite.rag_match_fuzzy:RAGMatch
8+
args:
9+
samples_jsonl: 01_alloychart/samples.jsonl

Diff for: evals/registry/evals/01_scipaper_drugchart.yaml

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
drugchart:
3+
id: drug_ChartQA.dev.v0
4+
metrics: [accuracy]
5+
6+
drug_ChartQA.dev.v0:
7+
class: evals.elsuite.choice_match:RAGMatch
8+
args:
9+
samples_jsonl: drug_ChartQA/samples.jsonl

0 commit comments

Comments
 (0)