Skip to content

Commit 10b550c

Browse files
committed
more
1 parent 812654e commit 10b550c

File tree

5 files changed

+163
-7
lines changed

5 files changed

+163
-7
lines changed

evals/completion_fns/uni_finder.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
6767
files = {'file': open(kwargs["file_name"], 'rb')}
6868
data = {
6969
# 'pdf_parse_mode': self.pdf_parse_mode,
70-
'api_key': self.api_key
70+
'api_key': self.api_key,
71+
'model_engine': 'gpt',
7172
}
7273
response = requests.post(url, data=data, files=files)
7374
pdf_id = response.json()['pdf_token'] # 获得pdf的id,表示上传成功,后续可以使用这个id来指定pdf
@@ -88,10 +89,11 @@ def __call__(self, prompt: Union[str, list[dict]], **kwargs: Any) -> UniFinderCo
8889
prompt = CompletionPrompt(prompt).to_formatted_prompt()
8990

9091
payload = {
91-
"model_engine": self.model,
92+
# "model_engine": self.model,
9293
"pdf_token": pdf_token,
9394
"query": prompt,
94-
'api_key': self.api_key
95+
'api_key': self.api_key,
96+
'model_engine': 'gpt',
9597
}
9698
response = requests.post(url, json=payload, timeout=300)
9799
try:

evals/elsuite/choice_match.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import oss2
6+
from oss2.credentials import EnvironmentVariableCredentialsProvider
7+
8+
import evals
9+
import evals.metrics
10+
from evals.api import CompletionFn
11+
from evals.prompt.base import is_chat_prompt
12+
13+
14+
def init_oss():
15+
"""
16+
Initialize OSS client.
17+
"""
18+
# Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables.
19+
auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
20+
21+
# 设置 Endpoint
22+
endpoint = 'https://oss-cn-beijing.aliyuncs.com'
23+
24+
# 设置 Bucket
25+
bucket_name = 'dp-filetrans-bj'
26+
bucket = oss2.Bucket(auth, endpoint, bucket_name)
27+
28+
return bucket
29+
30+
31+
def get_rag_dataset(samples_jsonl: str) -> list[dict]:
32+
bucket = init_oss()
33+
raw_samples = evals.get_jsonl(samples_jsonl)
34+
35+
for raw_sample in raw_samples:
36+
for ftype in ["", "answer"]:
37+
if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample:
38+
continue
39+
if f"{ftype}file_name" in raw_sample:
40+
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"])
41+
raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file
42+
43+
exists = bucket.object_exists(oss_file)
44+
if exists:
45+
print(f"文件 {oss_file} 已存在于 OSS 中。")
46+
else:
47+
# 上传文件
48+
bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
49+
print(f"文件 {oss_file} 已上传到 OSS。")
50+
if f"{ftype}file_link" in raw_sample:
51+
local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \
52+
os.path.basename(raw_sample[f"{ftype}file_link"])
53+
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
54+
if not os.path.exists(local_file):
55+
if bucket.object_exists(oss_file):
56+
# 从 OSS 下载文件
57+
Path(local_file).parent.mkdir(parents=True, exist_ok=True)
58+
bucket.get_object_to_file(oss_file, local_file)
59+
print(f"文件 {oss_file} 已下载到本地。")
60+
return raw_samples
61+
62+
63+
class RAGMatch(evals.Eval):
64+
def __init__(
65+
self,
66+
completion_fns: list[CompletionFn],
67+
samples_jsonl: str,
68+
*args,
69+
max_tokens: int = 500,
70+
num_few_shot: int = 0,
71+
few_shot_jsonl: str = None,
72+
**kwargs,
73+
):
74+
super().__init__(completion_fns, *args, **kwargs)
75+
assert len(completion_fns) == 1, "Match only supports one completion fn"
76+
self.max_tokens = max_tokens
77+
self.samples_jsonl = samples_jsonl
78+
self.num_few_shot = num_few_shot
79+
if self.num_few_shot > 0:
80+
assert few_shot_jsonl is not None, "few shot requires few shot sample dataset"
81+
self.few_shot_jsonl = few_shot_jsonl
82+
self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))
83+
self.choice = ["a)", "b)", "c)", "d)"]
84+
85+
def eval_sample(self, sample: Any, *_):
86+
assert isinstance(sample, dict), "sample must be a dict"
87+
assert "input" in sample, "sample must have an 'input' key"
88+
assert "ideal" in sample, "sample must have an 'ideal' key"
89+
assert isinstance(sample["ideal"], str) or isinstance(
90+
sample["ideal"], list
91+
), "sample['ideal'] must be a string or list of strings"
92+
93+
prompt = sample["input"]
94+
if self.num_few_shot > 0:
95+
assert is_chat_prompt(sample["input"]), "few shot requires chat prompt"
96+
prompt = sample["input"][:-1]
97+
for s in self.few_shot[: self.num_few_shot]:
98+
prompt += s["sample"]
99+
prompt += sample["input"][-1:]
100+
101+
result = self.completion_fn(
102+
prompt=prompt,
103+
temperature=0.0,
104+
**{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
105+
)
106+
sampled = result.get_completions()[0]
107+
108+
extras = {}
109+
if hasattr(result, "extras"):
110+
if "extracted_answer" in result.extras:
111+
sampled = result.extras["extracted_answer"].rstrip(".")
112+
extras = result.extras
113+
114+
sampled_tmp = sampled.split("\n")[-1]
115+
choice = sample["ideal"][:2]
116+
if choice in sampled_tmp:
117+
for i in self.choice:
118+
if i == choice:
119+
continue
120+
elif i in sampled_tmp:
121+
sampled = ""
122+
break
123+
sampled = sample["ideal"]
124+
else:
125+
sampled = ""
126+
127+
return_result = evals.record_and_check_match(
128+
prompt=prompt,
129+
sampled=sampled,
130+
expected=sample["ideal"],
131+
file_name=sample["file_name"],
132+
**extras
133+
)
134+
print("checkresult----------------------------------")
135+
print("sampled", sampled)
136+
print("ideal", sample["ideal"])
137+
print("check result", return_result)
138+
print("end----------------------------------")
139+
return return_result
140+
141+
def run(self, recorder):
142+
samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
143+
self.eval_all_samples(recorder, samples)
144+
events = recorder.get_events("match")
145+
return {
146+
"accuracy": evals.metrics.get_accuracy(events),
147+
"boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
148+
}

evals/elsuite/rag_match_fuzzy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def eval_sample(self, sample: Any, *_):
109109
print(sampled)
110110
try:
111111
# pattern = re.compile(r'\w\)[\s\d+]?\s?[°]?[CK]?')
112-
pattern = re.compile(r'\w\)\s\d+(?:\.\d+)?\s?[°]?[CK]?') # 包含整数小数
112+
pattern = re.compile(r'\w\)\s\d+(?:\.\d+)?(?:\s?:\s?\d+(?:\.\d+)?)?\s?[°]?[CK]?') # 包含整数小数比例
113113

114114
sampled0 = pattern.findall(sampled)
115115
if sampled0 is None or sampled0==[]:

evals/registry/completion_fns/uni_finder.yaml

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,10 @@ uni_finder-v1.26:
77
uni_finder-v1.26-cot:
88
class: evals.completion_fns.cot:ChainOfThoughtCompletionFn
99
args:
10-
cot_completion_fn: uni_finder-v1.26
10+
cot_completion_fn: uni_finder-v1.26
11+
12+
13+
uni_finder-v3.07:
14+
class: evals.completion_fns.uni_finder:UniFinderCompletionFn
15+
args:
16+
pdf_parse_mode: v3.07
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11

2-
drug_ChartQA:
2+
drugchart:
33
id: drug_ChartQA.dev.v0
44
metrics: [accuracy]
55

66
drug_ChartQA.dev.v0:
7-
class: evals.elsuite.rag_match:RAGMatch
7+
class: evals.elsuite.choice_match:RAGMatch
88
args:
99
samples_jsonl: drug_ChartQA/samples.jsonl

0 commit comments

Comments
 (0)