Skip to content

Commit 6c0fb06

Browse files
authored
Update active_pipeline.py
1 parent a8fdb1b commit 6c0fb06

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

flashrag/pipeline/active_pipeline.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import re
22
from tqdm import tqdm
3+
import math
34
import numpy as np
45
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
56
from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1015,7 +1016,6 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):
10151016
dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
10161017
return dataset
10171018

1018-
import math
10191019
class RQRAGPipeline(BasicPipeline):
10201020
expand_on_tokens = [
10211021
"[S_Rewritten_Query]",
@@ -1063,19 +1063,15 @@ def __init__(
10631063
batch_size = 32
10641064
):
10651065
super().__init__(config, prompt_template)
1066-
from flashrag.generator import VLLMGenerator
1067-
from flashrag.retriever import BaseRetriever
1068-
1069-
self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side = "left")
1066+
1067+
10701068
self.generator = generator if generator is not None else get_generator(config)
1069+
self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side = "left")
10711070
self.retriever = retriever if retriever is not None else get_retriever(config)
10721071

10731072
self.max_depth = max_depth
10741073
self.batch_size = batch_size
10751074

1076-
self.generator: VLLMGenerator
1077-
self.retriever: BaseRetriever
1078-
10791075
# Due to the low efficiency of the original method, only vllm is supported for now.
10801076

10811077
def preprocess_eval_data(self, items: List[Item]) -> List[str]:
@@ -1238,4 +1234,4 @@ def run(self, dataset: Dataset, do_eval = True):
12381234

12391235
dataset = self.evaluate(dataset, do_eval = do_eval)
12401236
return dataset
1241-
1237+

0 commit comments

Comments
 (0)