@@ -1,5 +1,6 @@
 import re
 from tqdm import tqdm
+import math
 import numpy as np
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 from flashrag.utils import get_retriever, get_generator, selfask_pred_parse, ircot_pred_parse
@@ -1015,7 +1016,6 @@ def run(self, dataset, do_eval=True, pred_process_fun=ircot_pred_parse):
         dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
         return dataset
 
-import math
 class RQRAGPipeline(BasicPipeline):
     expand_on_tokens = [
         "[S_Rewritten_Query]",
@@ -1063,19 +1063,15 @@ def __init__(
         batch_size = 32
     ):
         super().__init__(config, prompt_template)
-        from flashrag.generator import VLLMGenerator
-        from flashrag.retriever import BaseRetriever
-
-        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side="left")
+
+
         self.generator = generator if generator is not None else get_generator(config)
+        self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model_path"], padding_side="left")
         self.retriever = retriever if retriever is not None else get_retriever(config)
 
         self.max_depth = max_depth
         self.batch_size = batch_size
 
-        self.generator: VLLMGenerator
-        self.retriever: BaseRetriever
-
         # Due to the low efficiency of the original method, only vllm is supported for now.
 
     def preprocess_eval_data(self, items: List[Item]) -> List[str]:
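Net effect of the hunk above: the generator is constructed first via `get_generator(config)`, the tokenizer is then loaded from the same `generator_model_path`, and the eager `VLLMGenerator`/`BaseRetriever` imports and type annotations are dropped. A minimal construction sketch under those assumptions; the config values are placeholders, not taken from the repository, and the remaining constructor parameters keep their defaults:

    # Hedged sketch: the config keys mirror the ones __init__ reads above.
    # The model path is a placeholder; the generator backend must be vllm,
    # per the comment in the hunk.
    config = {
        "generator_model_path": "path/to/generator-model",  # placeholder
        # ... other FlashRAG config entries (retriever settings, generator framework, ...)
    }
    pipeline = RQRAGPipeline(config, max_depth=3, batch_size=32)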
@@ -1238,4 +1234,4 @@ def run(self, dataset: Dataset, do_eval = True):
 
         dataset = self.evaluate(dataset, do_eval=do_eval)
         return dataset
-
+
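For completeness, a hedged end-to-end sketch of driving `run` after this change. `Config` and `get_dataset` follow FlashRAG's usual entry points; the YAML filename, the split name, and the import location of `RQRAGPipeline` are assumptions:

    from flashrag.config import Config
    from flashrag.utils import get_dataset
    from flashrag.pipeline import RQRAGPipeline  # assumed export path

    config = Config("my_config.yaml")        # assumed config file defining a vllm generator
    test_data = get_dataset(config)["test"]  # assumed split name

    pipeline = RQRAGPipeline(config)
    result = pipeline.run(test_data, do_eval=True)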