
Merge pull request #108 from YuyaoZhangQAQ/feature/rqrag-support
rqrag support
ignorejjj authored Dec 8, 2024
2 parents 3cac6a8 + 6918785 commit b8bdf7a
Showing 2 changed files with 249 additions and 0 deletions.
28 changes: 28 additions & 0 deletions examples/methods/run_exp.py
@@ -552,6 +552,32 @@ def adaptive(args):
pipeline = AdaptivePipeline(config)
result = pipeline.run(test_data)

def rqrag(args):
"""
Function to run the RQRAGPipeline.
"""
from flashrag.pipeline.active_pipeline import RQRAGPipeline

save_note = "rqrag"

config_dict = {
"save_note": save_note,
"gpu_id": args.gpu_id,
        "framework": "vllm",
"dataset_name": args.dataset_name,
"split": args.split,
"retrieval_topk": 5,
"max_depth": 3
}

config = Config("my_config.yaml", config_dict)

all_split = get_dataset(config)
test_data = all_split[args.split]

    # argparse defines no --max_depth flag here, so take it from config_dict
    pipeline = RQRAGPipeline(config, max_depth=config_dict["max_depth"])
result = pipeline.run(test_data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Running exp")
@@ -560,6 +586,7 @@ def adaptive(args):
parser.add_argument("--dataset_name", type=str)
parser.add_argument("--gpu_id", type=str)


func_dict = {
"AAR-contriever": aar,
"AAR-ANCE": aar,
@@ -578,6 +605,7 @@ def adaptive(args):
"ircot": ircot,
"trace": trace,
"adaptive": adaptive,
"rqrag": rqrag,
}

args = parser.parse_args()
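A quick smoke test for the new entry point might look like the sketch below. It calls the rqrag() helper directly with a hand-built namespace and assumes a my_config.yaml is present; the dataset name is a hypothetical placeholder:

    import argparse
    from run_exp import rqrag  # assumes examples/methods/ is on sys.path

    args = argparse.Namespace(
        dataset_name="hotpotqa",  # hypothetical dataset
        split="test",
        gpu_id="0",
    )
    rqrag(args)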
221 changes: 221 additions & 0 deletions flashrag/pipeline/active_pipeline.py
@@ -1,5 +1,6 @@
import re
from tqdm import tqdm
import math
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1014,3 +1015,223 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):

dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
return dataset

class RQRAGPipeline(BasicPipeline):
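    # RQ-RAG special control tokens: three query-refinement actions
    # (rewrite, decompose, disambiguate) plus the final-answer token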
expand_on_tokens = [
"[S_Rewritten_Query]",
"[S_Decomposed_Query]",
"[S_Disambiguated_Query]",
"[A_Response]"
]

system_prompt = {
"qa": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response."
}

response_generation_params = {
"temperature": 0,
"top_p": 0.9,
"stop": ["[EOS]", "</s>"],
"skip_special_tokens": False,
"include_stop_str_in_output": True,
"logprobs": 1,
"spaces_between_special_tokens": False,
"max_tokens": 4096
}

other_generation_params = {
"temperature": 1,
"top_p": 0.9,
"stop": ["[EOS]", "</s>"],
"skip_special_tokens": False,
"include_stop_str_in_output": True,
"logprobs": 1,
"spaces_between_special_tokens": False,
"max_tokens": 4096
}
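    # The two parameter sets differ only in temperature: greedy decoding keeps
    # the final [A_Response] deterministic, while temperature 1 diversifies the
    # query expansions across tree branches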

    # Class-scope imports used by the method annotations below
    from flashrag.dataset import Dataset, Item
    from typing import List

    def __init__(
        self,
        config: dict,
        prompt_template=None,
        retriever=None,
        generator=None,
        max_depth=3,
        batch_size=32
    ):
        super().__init__(config, prompt_template)

        self.generator = generator if generator is not None else get_generator(config)
        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side="left")
        self.retriever = retriever if retriever is not None else get_retriever(config)

        self.max_depth = max_depth
        self.batch_size = batch_size

        # Due to the low efficiency of the original method, only the vLLM framework is supported for now.

def preprocess_eval_data(self, items: List[Item]) -> List[str]:
eval_examples = []

for item in items:
eval_example = f"<s><|system|>\n{self.system_prompt['qa']}" + self.tokenizer.eos_token + "\n<|user|>\n" + item.question + self.tokenizer.eos_token + "\n"
eval_example += "<|assistant|>\n"
eval_examples.append(eval_example)

return eval_examples

    def format_evidences(self, evidences: List[dict]):
format_evidence = ""
for evidence in evidences:
title = evidence['contents'].split('\n')[0]
text = "\n".join(evidence['contents'].split('\n')[1:])
format_evidence += f"Title: {title}\n"
format_evidence += f"Text: {text}\n"
return format_evidence
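    # Illustration (hypothetical values): an evidence dict such as
    #   {"contents": "Eiffel Tower\nThe tower is 330 metres tall."}
    # is rendered as
    #   Title: Eiffel Tower
    #   Text: The tower is 330 metres tall.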

def generate_tree_of_thoughts_batch(self, initial_prompts_batch: List[str]):
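        # Breadth-style tree expansion: each sample keeps a queue of pending
        # paths; every round pops one path per sample, expands it with each
        # special token, attaches retrieved evidence to the query expansions,
        # and collects finished [A_Response] paths into final_outputs_batch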
paths_batch_dict = {
idx: [{
"prompt": initial_prompt,
"depth": 0,
"done": False
}]
for idx, initial_prompt in enumerate(initial_prompts_batch)
}

final_outputs_batch = {idx: [] for idx in range(len(initial_prompts_batch))}

while any(paths for paths in paths_batch_dict.values()):
            current_batch = []
            batch_indices = []  # original sample index for each path in current_batch
            for idx, paths in paths_batch_dict.items():
                if paths:
                    current_batch.append(paths.pop(0))
                    batch_indices.append(idx)

if not current_batch:
break

for special_token in self.expand_on_tokens:
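                # Note: the depth check below uses the first path in the
                # current batch as a proxy for the whole batch's depth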

if current_batch[0]["depth"] >= self.max_depth and special_token != "[A_Response]":
continue

# Prepare for inputs
input_texts = [path["prompt"] + special_token for path in current_batch]

                # Generate outputs: greedy decoding for the final answer,
                # sampled decoding for the query expansions
                if special_token == "[A_Response]":
                    init_outputs = self.generator.generate(
                        input_list=input_texts,
                        return_raw_output=True,
                        **self.response_generation_params
                    )
                else:
                    init_outputs = self.generator.generate(
                        input_list=input_texts,
                        return_raw_output=True,
                        **self.other_generation_params
                    )

# Decode outputs
decoded_outputs = [output.outputs[0].text for output in init_outputs]
# Initialize lists to collect queries for batch retrieval
queries_for_search = []

# Process outputs and prepare for retrieval
for i, decoded_output in enumerate(decoded_outputs):
current_path = current_batch[i]
decoded_output = decoded_output.replace("<s> ", "<s>")

if special_token == "[A_Response]":
pattern = r"(.*?)\[EOS\]"
matches = re.findall(pattern, decoded_output, re.DOTALL)
result = matches[-1].strip() if matches else "Unable to detect valid answer"
                        token_ids = init_outputs[i].outputs[0].token_ids[1:-1]
                        logprobs = init_outputs[i].outputs[0].logprobs[1:-1]
                        # Confidence is the mean per-token probability of the generated answer
                        confidence = 0
                        for token_id, token_logprobs in zip(token_ids, logprobs):
                            logprob = token_logprobs[token_id].logprob
                            prob = math.exp(logprob)
                            confidence += prob

if len(token_ids) > 0:
confidence /= len(token_ids)

new_path = {
"prompt": input_texts[i] + decoded_output,
"depth": current_path["depth"] + 1,
"done": True,
"final_answer": result,
"confidence": confidence
}
                        final_outputs_batch[batch_indices[i]].append(new_path)
else:
# Extract the query
pattern = r"(.*?)\[EOS\]"
matches = re.findall(pattern, decoded_output, re.DOTALL)
query_for_search = matches[-1].strip() if matches else "dummy"
queries_for_search.append(query_for_search)

# Perform batch retrieval
if queries_for_search:
batch_search_results = self.retriever.batch_search(queries_for_search)

                    for i, decoded_output in enumerate(decoded_outputs):
                        decoded_output = decoded_output.replace("<s> ", "<s>")  # same cleanup as above
                        search_results = batch_search_results[i]
                        format_evidence = self.format_evidences(search_results)
                        new_prompt = decoded_output + "[R_Evidences]" + format_evidence + "[/R_Evidences]"
                        new_path = {
                            "prompt": input_texts[i] + new_prompt,
                            "depth": current_batch[i]["depth"] + 1,  # this path's own depth, not a stale loop variable
                            "done": False,
                        }
                        paths_batch_dict[batch_indices[i]].append(new_path)

final_outputs_batch_list = [final_outputs_batch[i] for i in range(len(initial_prompts_batch))]

return final_outputs_batch_list

def select_best_path_single_turn(self, final_outputs):
        # After generating all paths, select the answer from the path with the
        # highest confidence (mean token probability)
        scores = [(path, path["confidence"]) for path in final_outputs]
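        # Worked example (hypothetical numbers): per-token probabilities of
        # 0.9, 0.8 and 0.7 give a confidence of (0.9 + 0.8 + 0.7) / 3 = 0.8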

        best_path = max(scores, key=lambda x: x[1])[0]  # x[1] is the confidence
pred = best_path["final_answer"]

return pred, best_path

    def run(self, dataset: Dataset, do_eval=True):
preds = []
meta_results = []

        for i in tqdm(range(0, len(dataset), self.batch_size)):
batch_items = dataset[i : i + self.batch_size]
eval_datas = self.preprocess_eval_data(batch_items)
paths_batch = self.generate_tree_of_thoughts_batch(initial_prompts_batch = eval_datas)
for paths in paths_batch:
pred, best_path = self.select_best_path_single_turn(paths)
preds.append(pred)
                meta_results.append(best_path)

        dataset.update_output("paths", meta_results)
        dataset.update_output("pred", preds)

        dataset = self.evaluate(dataset, do_eval=do_eval)
return dataset
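For completeness, a minimal standalone sketch of driving the pipeline without run_exp.py. The config keys mirror the rqrag() helper above; the import paths are assumed from the surrounding codebase, and the dataset name is a hypothetical placeholder:

    from flashrag.config import Config
    from flashrag.utils import get_dataset
    from flashrag.pipeline.active_pipeline import RQRAGPipeline

    config = Config("my_config.yaml", {
        "save_note": "rqrag",
        "gpu_id": "0",
        "framework": "vllm",         # the pipeline currently assumes vLLM
        "dataset_name": "hotpotqa",  # hypothetical dataset
        "split": "test",
        "retrieval_topk": 5,
        "max_depth": 3,
    })
    test_data = get_dataset(config)["test"]

    pipeline = RQRAGPipeline(config, max_depth=3, batch_size=32)
    result = pipeline.run(test_data, do_eval=True)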
