diff --git a/examples/methods/run_exp.py b/examples/methods/run_exp.py
index fbe7795..5dbfe39 100644
--- a/examples/methods/run_exp.py
+++ b/examples/methods/run_exp.py
@@ -552,6 +552,32 @@ def adaptive(args):
     pipeline = AdaptivePipeline(config)
     result = pipeline.run(test_data)
 
+def rqrag(args):
+    """
+    Function to run the RQRAGPipeline.
+    """
+    from flashrag.pipeline.active_pipeline import RQRAGPipeline
+
+    save_note = "rqrag"
+
+    config_dict = {
+        "save_note": save_note,
+        "gpu_id": args.gpu_id,
+        "framework": "vllm",
+        "dataset_name": args.dataset_name,
+        "split": args.split,
+        "retrieval_topk": 5,
+        "max_depth": 3,
+    }
+
+    config = Config("my_config.yaml", config_dict)
+
+    all_split = get_dataset(config)
+    test_data = all_split[args.split]
+
+    pipeline = RQRAGPipeline(config, max_depth = config_dict["max_depth"])
+    result = pipeline.run(test_data)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Running exp")
@@ -560,6 +586,7 @@ def adaptive(args):
     parser.add_argument("--dataset_name", type=str)
     parser.add_argument("--gpu_id", type=str)
 
+
     func_dict = {
         "AAR-contriever": aar,
         "AAR-ANCE": aar,
@@ -578,6 +605,7 @@ def adaptive(args):
         "ircot": ircot,
         "trace": trace,
         "adaptive": adaptive,
+        "rqrag": rqrag,
     }
 
     args = parser.parse_args()
diff --git a/flashrag/pipeline/active_pipeline.py b/flashrag/pipeline/active_pipeline.py
index a02f3c7..704e7e0 100644
--- a/flashrag/pipeline/active_pipeline.py
+++ b/flashrag/pipeline/active_pipeline.py
@@ -1,5 +1,6 @@
 import re
 from tqdm import tqdm
+import math
 import numpy as np
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1014,3 +1015,223 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):
         dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
 
         return dataset
+
+
+class RQRAGPipeline(BasicPipeline):
+    expand_on_tokens = [
+        "[S_Rewritten_Query]",
+        "[S_Decomposed_Query]",
+        "[S_Disambiguated_Query]",
+        "[A_Response]"
+    ]
+
+    system_prompt = {
+        "qa": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response."
+    }
+
+    # Deterministic decoding for the final [A_Response] step.
+    response_generation_params = {
+        "temperature": 0,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
+
+    # Sampling-based decoding for the query rewriting / decomposition /
+    # disambiguation steps, where diverse expansions are desirable.
+    other_generation_params = {
+        "temperature": 1,
+        "top_p": 0.9,
+        "stop": ["[EOS]", "</s>"],
+        "skip_special_tokens": False,
+        "include_stop_str_in_output": True,
+        "logprobs": 1,
+        "spaces_between_special_tokens": False,
+        "max_tokens": 4096
+    }
+
+    from flashrag.dataset import Dataset, Item
+    from typing import List, Tuple
+
+    def __init__(
+        self,
+        config: dict,
+        prompt_template = None,
+        retriever = None,
+        generator = None,
+        max_depth = 3,
+        batch_size = 32
+    ):
+        super().__init__(config, prompt_template)
+
+        self.generator = generator if generator is not None else get_generator(config)
+        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side = "left")
+        self.retriever = retriever if retriever is not None else get_retriever(config)
+
+        self.max_depth = max_depth
+        self.batch_size = batch_size
+
+        # Due to the low efficiency of the original method, only the vLLM framework is supported for now.
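+        # How the search tree grows (sketch): each queued node's prompt is
+        # extended once per special token in expand_on_tokens. The three
+        # [S_*_Query] tokens make the model emit a refined query, whose
+        # retrieved evidence is spliced back into the prompt before the branch
+        # is re-queued one level deeper; [A_Response] terminates a branch with
+        # a candidate answer and a confidence score.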
+
+    def preprocess_eval_data(self, items: List[Item]) -> List[str]:
+        eval_examples = []
+
+        for item in items:
+            eval_example = f"<|system|>\n{self.system_prompt['qa']}" + self.tokenizer.eos_token + "\n<|user|>\n" + item.question + self.tokenizer.eos_token + "\n"
+            eval_example += "<|assistant|>\n"
+            eval_examples.append(eval_example)
+
+        return eval_examples
+
+    def format_evidences(self, evidences: List[dict]) -> str:
+        # Each evidence is a retrieval hit whose "contents" field holds the
+        # title on the first line and the passage text below it.
+        format_evidence = ""
+        for evidence in evidences:
+            title = evidence['contents'].split('\n')[0]
+            text = "\n".join(evidence['contents'].split('\n')[1:])
+            format_evidence += f"Title: {title}\n"
+            format_evidence += f"Text: {text}\n"
+        return format_evidence
+
+    def generate_tree_of_thoughts_batch(self, initial_prompts_batch: List[str]):
+        paths_batch_dict = {
+            idx: [{
+                "prompt": initial_prompt,
+                "depth": 0,
+                "done": False
+            }]
+            for idx, initial_prompt in enumerate(initial_prompts_batch)
+        }
+
+        final_outputs_batch = {idx: [] for idx in range(len(initial_prompts_batch))}
+
+        # Note: every tree is expanded in lockstep (one node per tree per
+        # round), so the position of a path in current_batch matches its key
+        # in paths_batch_dict.
+        while any(paths for paths in paths_batch_dict.values()):
+            current_batch = []
+            for idx in paths_batch_dict:
+                if paths_batch_dict[idx]:
+                    current_batch.append(paths_batch_dict[idx].pop(0))
+
+            if not current_batch:
+                break
+
+            for special_token in self.expand_on_tokens:
+                # Once a branch reaches max_depth, only answering is allowed.
+                if current_batch[0]["depth"] >= self.max_depth and special_token != "[A_Response]":
+                    continue
+
+                # Prepare the inputs
+                input_texts = [path["prompt"] + special_token for path in current_batch]
+
+                # Generate outputs: deterministic decoding for the final
+                # answer, sampling for the query-expansion tokens
+                if special_token == "[A_Response]":
+                    init_outputs = self.generator.generate(
+                        input_list = input_texts,
+                        return_raw_output = True,
+                        **self.response_generation_params
+                    )
+                else:
+                    init_outputs = self.generator.generate(
+                        input_list = input_texts,
+                        return_raw_output = True,
+                        **self.other_generation_params
+                    )
+
+                # Decode outputs
+                decoded_outputs = [output.outputs[0].text for output in init_outputs]
+                # Initialize a list to collect queries for batch retrieval
+                queries_for_search = []
+
+                # Process outputs and prepare for retrieval
+                for i, decoded_output in enumerate(decoded_outputs):
+                    current_path = current_batch[i]
+                    decoded_output = decoded_output.replace("<s> ", "")
+
+                    if special_token == "[A_Response]":
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        result = matches[-1].strip() if matches else "Unable to detect valid answer"
+
+                        # Use the mean token probability of the answer span as
+                        # the branch's confidence score.
+                        token_ids = init_outputs[i].outputs[0].token_ids[1:-1]
+                        logprobs = init_outputs[i].outputs[0].logprobs[1:-1]
+                        confidence = 0
+                        for token_id, token_logprobs in zip(token_ids, logprobs):
+                            logprob = token_logprobs[token_id].logprob
+                            confidence += math.exp(logprob)
+
+                        if len(token_ids) > 0:
+                            confidence /= len(token_ids)
+
+                        new_path = {
+                            "prompt": input_texts[i] + decoded_output,
+                            "depth": current_path["depth"] + 1,
+                            "done": True,
+                            "final_answer": result,
+                            "confidence": confidence
+                        }
+                        final_outputs_batch[i].append(new_path)
+                    else:
+                        # Extract the query
+                        pattern = r"(.*?)\[EOS\]"
+                        matches = re.findall(pattern, decoded_output, re.DOTALL)
+                        query_for_search = matches[-1].strip() if matches else "dummy"
+                        queries_for_search.append(query_for_search)
+
+                # Perform batch retrieval
+                if queries_for_search:
+                    batch_search_results = self.retriever.batch_search(queries_for_search)
+
+                    for i, decoded_output in enumerate(decoded_outputs):
+                        current_path = current_batch[i]
+                        search_results = batch_search_results[i]
+                        format_evidence = self.format_evidences(search_results)
+                        new_prompt = decoded_output + "[R_Evidences]" + format_evidence + "[/R_Evidences]"
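+                        # Splice the retrieved evidence into the branch's
+                        # prompt so the next expansion step conditions on it.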
format_evidence + "[/R_Evidences]" + new_path = { + "prompt": input_texts[i] + new_prompt, + "depth": current_path["depth"] + 1, + "done": False, + } + paths_batch_dict[i].append(new_path) + + final_outputs_batch_list = [final_outputs_batch[i] for i in range(len(initial_prompts_batch))] + + return final_outputs_batch_list + + def select_best_path_single_turn(self, final_outputs): + # After generating all paths, we can select the best answer + # Compute perplexity and confidence for each path + + scores = [] + for path in final_outputs: + confidence = path["confidence"] + path["confidence"] = confidence + scores.append((path, confidence)) + + # Select the path with the highest confidence + best_path = max(scores, key = lambda x: x[1])[0] # x[2] is confidence + pred = best_path["final_answer"] + + return pred, best_path + + def run(self, dataset: Dataset, do_eval = True): + preds = [] + meta_results = [] + + from tqdm import tqdm + for i in tqdm(range(0, len(dataset), self.batch_size)): + batch_items = dataset[i : i + self.batch_size] + eval_datas = self.preprocess_eval_data(batch_items) + paths_batch = self.generate_tree_of_thoughts_batch(initial_prompts_batch = eval_datas) + for paths in paths_batch: + pred, best_path = self.select_best_path_single_turn(paths) + preds.append(pred) + meta_results.append(best_path) + + + dataset.update_output("paths", meta_results) + dataset.update_output("pred", preds) + + dataset = self.evaluate(dataset, do_eval = do_eval) + return dataset +
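
Example invocation (a minimal sketch; it assumes run_exp.py's existing --method_name and --split arguments and a prepared my_config.yaml, and the dataset/GPU values below are placeholders):

    python examples/methods/run_exp.py --method_name rqrag --dataset_name hotpotqa --split dev --gpu_id 0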