import re
from tqdm import tqdm
+import math
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1014,3 +1015,223 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):

        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
        return dataset
+
+class RQRAGPipeline(BasicPipeline):
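+    """RQ-RAG pipeline: expands each question into rewritten, decomposed, and
+    disambiguated sub-queries, retrieves evidence for every expansion, and keeps
+    the [A_Response] path whose answer tokens have the highest mean probability."""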
+    expand_on_tokens = [
+        "[S_Rewritten_Query]",
+        "[S_Decomposed_Query]",
+        "[S_Disambiguated_Query]",
+        "[A_Response]"
+    ]
+
+    system_prompt = {
+        "qa": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response."
+    }
+
+    response_generation_params = {
+        "temperature": 0,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
+
+    other_generation_params = {
+        "temperature": 1,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
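+    # Note: response_generation_params (greedy, temperature 0) is applied when expanding the
+    # three [S_*] query tokens; other_generation_params (temperature 1) is applied when
+    # generating the final [A_Response] (see generate_tree_of_thoughts_batch). Both request
+    # logprobs; the [A_Response] branch uses them to score answer confidence.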
+
+    # Imported at class level so the annotations below (List[Item], Dataset) resolve.
+    from flashrag.dataset import Dataset, Item
+    from typing import List, Tuple
+
+    def __init__(
+        self,
+        config: dict,
+        prompt_template=None,
+        retriever=None,
+        generator=None,
+        max_depth=3,
+        batch_size=32
+    ):
+        super().__init__(config, prompt_template)
+
+        self.generator = generator if generator is not None else get_generator(config)
+        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side="left")
+        self.retriever = retriever if retriever is not None else get_retriever(config)
+
+        self.max_depth = max_depth
+        self.batch_size = batch_size
+
+        # Due to the low efficiency of the original method, only the vLLM backend is supported for now.
+
+    def preprocess_eval_data(self, items: List[Item]) -> List[str]:
+        # Build the chat-style prompt: system instruction, user question, then an open assistant turn.
+        eval_examples = []
+
+        for item in items:
+            eval_example = (
+                f"<s><|system|>\n{self.system_prompt['qa']}{self.tokenizer.eos_token}\n"
+                f"<|user|>\n{item.question}{self.tokenizer.eos_token}\n"
+                "<|assistant|>\n"
+            )
+            eval_examples.append(eval_example)
+
+        return eval_examples
+
+    def format_evidences(self, evidences: List[dict]) -> str:
+        # Each retrieved document stores "title\ntext" in its 'contents' field.
+        format_evidence = ""
+        for evidence in evidences:
+            title = evidence['contents'].split('\n')[0]
+            text = "\n".join(evidence['contents'].split('\n')[1:])
+            format_evidence += f"Title: {title}\n"
+            format_evidence += f"Text: {text}\n"
+        return format_evidence
+
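+    # Breadth-first expansion over the batch: every pending path is extended with each
+    # special token; [A_Response] closes a path with an answer and a confidence score,
+    # while the three [S_*] tokens trigger retrieval and spawn a deeper pending path
+    # with the evidence appended.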
+    def generate_tree_of_thoughts_batch(self, initial_prompts_batch: List[str]):
+        # One list of pending (not yet answered) paths per question in the batch.
+        paths_batch_dict = {
+            idx: [{
+                "prompt": initial_prompt,
+                "depth": 0,
+                "done": False
+            }]
+            for idx, initial_prompt in enumerate(initial_prompts_batch)
+        }
+
+        final_outputs_batch = {idx: [] for idx in range(len(initial_prompts_batch))}
+
+        while any(paths for paths in paths_batch_dict.values()):
+            # Pop one pending path per question. All questions are expanded in lock-step,
+            # so current_batch preserves the original order and position i can be used as
+            # the question index below.
+            current_batch = []
+            for i in paths_batch_dict:
+                if paths_batch_dict[i]:
+                    current_batch.append(paths_batch_dict[i].pop(0))
+
+            if not current_batch:
+                break
+
+            for special_token in self.expand_on_tokens:
+                # Beyond max_depth, only the answer token is expanded (all paths in
+                # current_batch share the same depth).
+                if current_batch[0]["depth"] >= self.max_depth and special_token != "[A_Response]":
+                    continue
+
+                # Prepare inputs
+                input_texts = [path["prompt"] + special_token for path in current_batch]
+
+                # Generate outputs
+                if special_token != "[A_Response]":
+                    init_outputs = self.generator.generate(
+                        input_list=input_texts,
+                        return_raw_output=True,
+                        **self.response_generation_params
+                    )
+                else:
+                    init_outputs = self.generator.generate(
+                        input_list=input_texts,
+                        return_raw_output=True,
+                        **self.other_generation_params
+                    )
+
+                # Decode outputs
+                decoded_outputs = [output.outputs[0].text for output in init_outputs]
+                # Collect queries for batch retrieval
+                queries_for_search = []
+
+                # Process outputs and prepare for retrieval
+                for i, decoded_output in enumerate(decoded_outputs):
+                    current_path = current_batch[i]
+                    decoded_output = decoded_output.replace("<s> ", "<s>")
+
+                    if special_token == "[A_Response]":
+                        # Extract the answer before [EOS] and score it with the mean
+                        # per-token probability of the generated tokens.
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        result = matches[-1].strip() if matches else "Unable to detect valid answer"
+                        token_ids = init_outputs[i].outputs[0].token_ids[1:-1]
+                        token_logprobs = init_outputs[i].outputs[0].logprobs[1:-1]
+                        confidence = 0
+                        for token_id, logprob_dict in zip(token_ids, token_logprobs):
+                            logprob = logprob_dict[token_id].logprob
+                            confidence += math.exp(logprob)
+
+                        if len(token_ids) > 0:
+                            confidence /= len(token_ids)
+
+                        new_path = {
+                            "prompt": input_texts[i] + decoded_output,
+                            "depth": current_path["depth"] + 1,
+                            "done": True,
+                            "final_answer": result,
+                            "confidence": confidence
+                        }
+                        final_outputs_batch[i].append(new_path)
+                    else:
+                        # Extract the expanded query
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        query_for_search = matches[-1].strip() if matches else "dummy"
+                        queries_for_search.append(query_for_search)
+
+                # Perform batch retrieval for the expanded queries. Within a single
+                # special-token pass every decoded output contributed one query, so the
+                # indices of decoded_outputs and batch_search_results stay aligned.
+                if queries_for_search:
+                    batch_search_results = self.retriever.batch_search(queries_for_search)
+
+                    for i, decoded_output in enumerate(decoded_outputs):
+                        search_results = batch_search_results[i]
+                        format_evidence = self.format_evidences(search_results)
+                        new_prompt = decoded_output + "[R_Evidences]" + format_evidence + "[/R_Evidences]"
+                        new_path = {
+                            "prompt": input_texts[i] + new_prompt,
+                            "depth": current_batch[i]["depth"] + 1,
+                            "done": False,
+                        }
+                        paths_batch_dict[i].append(new_path)
+
+        final_outputs_batch_list = [final_outputs_batch[i] for i in range(len(initial_prompts_batch))]
+
+        return final_outputs_batch_list
+
+    def select_best_path_single_turn(self, final_outputs):
+        # After generating all paths, select the best answer by confidence
+        # (mean token probability of the generated answer).
+        scores = [(path, path["confidence"]) for path in final_outputs]
+
+        # Select the path with the highest confidence
+        best_path = max(scores, key=lambda x: x[1])[0]  # x[1] is the confidence
+        pred = best_path["final_answer"]
+
+        return pred, best_path
+
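+    # Batched inference: questions are processed batch_size at a time; the best
+    # answer path for each question becomes its prediction, and the full path is
+    # kept as metadata.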
+    def run(self, dataset: Dataset, do_eval=True):
+        preds = []
+        meta_results = []
+
+        for i in tqdm(range(0, len(dataset), self.batch_size)):
+            batch_items = dataset[i : i + self.batch_size]
+            eval_prompts = self.preprocess_eval_data(batch_items)
+            paths_batch = self.generate_tree_of_thoughts_batch(initial_prompts_batch=eval_prompts)
+            for paths in paths_batch:
+                pred, best_path = self.select_best_path_single_turn(paths)
+                preds.append(pred)
+                meta_results.append(best_path)
+
+        dataset.update_output("paths", meta_results)
+        dataset.update_output("pred", preds)
+
+        dataset = self.evaluate(dataset, do_eval=do_eval)
+        return dataset
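+
+# Example usage (a rough sketch; `config` and `test_dataset` stand in for whatever the
+# surrounding FlashRAG setup provides, they are not defined in this file):
+#   pipeline = RQRAGPipeline(config, max_depth=2, batch_size=16)
+#   result_dataset = pipeline.run(test_dataset, do_eval=True)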