
Commit b8bdf7a

Merge pull request #108 from YuyaoZhangQAQ/feature/rqrag-support
rqrag support
2 parents 3cac6a8 + 6918785 commit b8bdf7a

2 files changed: +249 -0 lines changed


examples/methods/run_exp.py

Lines changed: 28 additions & 0 deletions
@@ -552,6 +552,32 @@ def adaptive(args):
     pipeline = AdaptivePipeline(config)
     result = pipeline.run(test_data)
 
+def rqrag(args):
+    """
+    Function to run the RQRAGPipeline.
+    """
+    from flashrag.pipeline.active_pipeline import RQRAGPipeline
+
+    save_note = "rqrag"
+
+    config_dict = {
+        "save_note": save_note,
+        "gpu_id": args.gpu_id,
+        'framework': 'vllm',
+        "dataset_name": args.dataset_name,
+        "split": args.split,
+        "retrieval_topk": 5,
+        "max_depth": 3
+    }
+
+    config = Config("my_config.yaml", config_dict)
+
+    all_split = get_dataset(config)
+    test_data = all_split[args.split]
+
+    pipeline = RQRAGPipeline(config, max_depth = args.max_depth)
+    result = pipeline.run(test_data)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Running exp")
@@ -560,6 +586,7 @@ def adaptive(args):
     parser.add_argument("--dataset_name", type=str)
     parser.add_argument("--gpu_id", type=str)
 
+
     func_dict = {
         "AAR-contriever": aar,
         "AAR-ANCE": aar,
@@ -578,6 +605,7 @@ def adaptive(args):
         "ircot": ircot,
         "trace": trace,
         "adaptive": adaptive,
+        "rqrag": rqrag,
     }
 
     args = parser.parse_args()
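
For quick reference, the sketch below shows the argparse.Namespace fields that the new rqrag(args) entry reads once run_exp.py dispatches through func_dict. Only --dataset_name and --gpu_id appear in this hunk, so the method_name, split, and max_depth fields (and all values) are assumptions inferred from the surrounding code, not flags confirmed by the diff.

# Hypothetical sketch (not part of the commit): the fields rqrag(args) reads.
# Flag names other than --dataset_name and --gpu_id are assumptions inferred
# from the surrounding code, since the diff does not show the full parser.
import argparse

args = argparse.Namespace(
    dataset_name="hotpotqa",   # assumed value; any dataset configured in my_config.yaml
    split="test",              # rqrag() uses args.split to pick the evaluation split
    gpu_id="0",
    max_depth=3,               # rqrag() passes args.max_depth to RQRAGPipeline
    method_name="rqrag",       # assumed dispatch key looked up in func_dict
)

# func_dict[args.method_name](args) would then run the new RQRAG pipeline end to end.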

flashrag/pipeline/active_pipeline.py

Lines changed: 221 additions & 0 deletions
@@ -1,5 +1,6 @@
 import re
 from tqdm import tqdm
+import math
 import numpy as np
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1014,3 +1015,223 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):
 
         dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
         return dataset
+
+class RQRAGPipeline(BasicPipeline):
+    expand_on_tokens = [
+        "[S_Rewritten_Query]",
+        "[S_Decomposed_Query]",
+        "[S_Disambiguated_Query]",
+        "[A_Response]"
+    ]
+
+    system_prompt = {
+        "qa": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response."
+    }
+
+    response_generation_params = {
+        "temperature": 0,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
+
+    other_generation_params = {
+        "temperature": 1,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
+
+    from flashrag.dataset import Dataset, Item
+    from typing import List, Tuple
+
+    def __init__(
+        self,
+        config: dict,
+        prompt_template = None,
+        retriever = None,
+        generator = None,
+        max_depth = 3,
+        batch_size = 32
+    ):
+        super().__init__(config, prompt_template)
+
+
+        self.generator = generator if generator is not None else get_generator(config)
+        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side = "left")
+        self.retriever = retriever if retriever is not None else get_retriever(config)
+
+        self.max_depth = max_depth
+        self.batch_size = batch_size
+
+        # Due to the low efficiency of the original method, it only supports vllm for now.
+
+    def preprocess_eval_data(self, items: List[Item]) -> List[str]:
+        eval_examples = []
+
+        for item in items:
+            eval_example = f"<s><|system|>\n{self.system_prompt['qa']}" + self.tokenizer.eos_token + "\n<|user|>\n" + item.question + self.tokenizer.eos_token + "\n"
+            eval_example += "<|assistant|>\n"
+            eval_examples.append(eval_example)
+
+        return eval_examples
+
+    def format_evidences(self, evidences: List[str]):
+        format_evidence = ""
+        for evidence in evidences:
+            title = evidence['contents'].split('\n')[0]
+            text = "\n".join(evidence['contents'].split('\n')[1:])
+            format_evidence += f"Title: {title}\n"
+            format_evidence += f"Text: {text}\n"
+        return format_evidence
+
+    def generate_tree_of_thoughts_batch(self, initial_prompts_batch: List[str]):
+        paths_batch_dict = {
+            idx: [{
+                "prompt": initial_prompt,
+                "depth": 0,
+                "done": False
+            }]
+            for idx, initial_prompt in enumerate(initial_prompts_batch)
+        }
+
+        final_outputs_batch = {idx: [] for idx in range(len(initial_prompts_batch))}
+
+        while any(paths for paths in paths_batch_dict.values()):
+            current_batch = []
+            for i, _ in paths_batch_dict.items():
+                if paths_batch_dict[i]:
+                    current_path = paths_batch_dict[i].pop(0)
+                    current_batch.append(current_path)
+                else:
+                    continue
+
+            if not current_batch:
+                break
+
+            for special_token in self.expand_on_tokens:
+
+                if current_batch[0]["depth"] >= self.max_depth and special_token != "[A_Response]":
+                    continue
+
+                # Prepare the inputs
+                input_texts = [path["prompt"] + special_token for path in current_batch]
+
+                # Generate outputs
+                if special_token != "[A_Response]":
+                    init_outputs = self.generator.generate(
+                        input_list = input_texts,
+                        return_raw_output = True,
+                        **self.response_generation_params
+                    )
+                else:
+                    init_outputs = self.generator.generate(
+                        input_list = input_texts,
+                        return_raw_output = True,
+                        **self.other_generation_params
+                    )
+
+                # Decode outputs
+                decoded_outputs = [output.outputs[0].text for output in init_outputs]
+                # Initialize a list to collect queries for batch retrieval
+                queries_for_search = []
+
+                # Process outputs and prepare for retrieval
+                for i, decoded_output in enumerate(decoded_outputs):
+                    current_path = current_batch[i]
+                    decoded_output = decoded_output.replace("<s> ", "<s>")
+
+                    if special_token == "[A_Response]":
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        result = matches[-1].strip() if matches else "Unable to detect valid answer"
+                        token_ids = init_outputs[i].outputs[0].token_ids[1:-1]
+                        logprobs = init_outputs[i].outputs[0].logprobs[1:-1]
+                        confidence = 0
+                        for token_id, token_logprobs in zip(token_ids, logprobs):
+                            logprob = token_logprobs[token_id].logprob
+                            prob = math.exp(logprob)
+                            confidence += prob
+
+                        if len(token_ids) > 0:
+                            confidence /= len(token_ids)
+
+                        new_path = {
+                            "prompt": input_texts[i] + decoded_output,
+                            "depth": current_path["depth"] + 1,
+                            "done": True,
+                            "final_answer": result,
+                            "confidence": confidence
+                        }
+                        final_outputs_batch[i].append(new_path)
+                    else:
+                        # Extract the query
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        query_for_search = matches[-1].strip() if matches else "dummy"
+                        queries_for_search.append(query_for_search)
+
+                # Perform batch retrieval
+                if queries_for_search:
+                    batch_search_results = self.retriever.batch_search(queries_for_search)
+
+                    for i, decoded_output in enumerate(decoded_outputs):
+                        search_results = batch_search_results[i]
+                        format_evidence = self.format_evidences(search_results)
+                        new_prompt = decoded_output + "[R_Evidences]" + format_evidence + "[/R_Evidences]"
+                        new_path = {
+                            "prompt": input_texts[i] + new_prompt,
+                            "depth": current_batch[i]["depth"] + 1,
+                            "done": False,
+                        }
+                        paths_batch_dict[i].append(new_path)
+
+        final_outputs_batch_list = [final_outputs_batch[i] for i in range(len(initial_prompts_batch))]
+
+        return final_outputs_batch_list
+
+    def select_best_path_single_turn(self, final_outputs):
+        # After generating all paths, we can select the best answer
+        # based on the confidence of each path
+
+        scores = []
+        for path in final_outputs:
+            confidence = path["confidence"]
+            path["confidence"] = confidence
+            scores.append((path, confidence))
+
+        # Select the path with the highest confidence
+        best_path = max(scores, key = lambda x: x[1])[0]  # x[1] is confidence
+        pred = best_path["final_answer"]
+
+        return pred, best_path
+
+    def run(self, dataset: Dataset, do_eval = True):
+        preds = []
+        meta_results = []
+
+        from tqdm import tqdm
+        for i in tqdm(range(0, len(dataset), self.batch_size)):
+            batch_items = dataset[i : i + self.batch_size]
+            eval_datas = self.preprocess_eval_data(batch_items)
+            paths_batch = self.generate_tree_of_thoughts_batch(initial_prompts_batch = eval_datas)
+            for paths in paths_batch:
+                pred, best_path = self.select_best_path_single_turn(paths)
+                preds.append(pred)
+                meta_results.append(best_path)
+
+
+        dataset.update_output("paths", meta_results)
+        dataset.update_output("pred", preds)
+
+        dataset = self.evaluate(dataset, do_eval = do_eval)
+        return dataset
+
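
For reference, the answer confidence used in the [A_Response] branch above is the mean per-token probability recovered from the generator's logprobs. A minimal, self-contained sketch of that calculation follows; the token ids and logprob values are made up, and plain floats stand in for the Logprob objects that vLLM returns (the pipeline reads them via .logprob).

# Minimal sketch of the answer-confidence score from the [A_Response] branch:
# the mean of exp(logprob) over the generated answer tokens. All values below
# are made up; vLLM returns Logprob objects rather than the plain floats used here.
import math

token_ids = [101, 7592, 2088]     # hypothetical token ids of the generated answer
logprobs = [                      # one {token_id: logprob} dict per generated position
    {101: -0.05},
    {7592: -0.20},
    {2088: -0.10},
]

confidence = 0.0
for token_id, token_logprobs in zip(token_ids, logprobs):
    confidence += math.exp(token_logprobs[token_id])   # probability of the sampled token

if token_ids:
    confidence /= len(token_ids)                        # average token probability in (0, 1]

print(round(confidence, 3))   # ~0.89 for these made-up numbers

select_best_path_single_turn then keeps the path with the highest such score, so this average is what ultimately decides the reported answer.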

0 commit comments
