11import json
2+ import re
23from typing import List
3- import re
4+
45import jieba
56import loguru
67
@@ -15,7 +16,7 @@ def cut(self, para: str):
1516
1617 # 定义结束符号列表
1718 # end_symbols = ['。', '!', '?', '…', ';', '\n'] # sent
18- end_symbols = ['。' , '!' , '?' , '…' , ';' , '\n ' ]# para
19+ end_symbols = ['。' , '!' , '?' , '…' , ';' , '\n ' ] # para
1920
2021 # 定义引号对
2122 quote_pairs = {'"' : '"' , "'" : "'" , '「' : '」' , '『' : '』' }
@@ -102,8 +103,7 @@ def highlight_common_substrings(self, response_content, select_content, min_leng
102103 best_match_positions = [[0 , len (select_content ) - 1 ]]
103104 return best_match_positions
104105
105-
106- def cal_common_ration (self ,response ,evidence ):
106+ def cal_common_ration (self , response , evidence ):
107107 """
108108 计算答案中的段落与匹配证据的相似度,或者重合度,直接居于共现词的比例
109109 """
@@ -114,8 +114,7 @@ def cal_common_ration(self,response,evidence):
114114 ratio = len (overlap ) / sentence_seg_cut_length
115115 return ratio
116116
117-
118- def extract_citations (self ,response :str = None ):
117+ def extract_citations (self , response : str = None ):
119118 """
120119 xxx[1]xxx[2],
121120 find all citation patterns like [number]
@@ -166,6 +165,7 @@ def extract_citations(self,response:str=None):
166165 "citations" : citations ,
167166 "parsed_result" : parsed_result
168167 }
168+
169169 def ground_response (
170170 self ,
171171 question : str ,
@@ -189,37 +189,34 @@ def ground_response(
189189
190190 # Save to JSON file
191191 try :
192- output_file = "/home/yanqiang/code/citation_match_llm .json"
192+ output_file = "/home/yanqiang/code/citation_match_llm_res .json"
193193 with open (output_file , 'w' , encoding = 'utf-8' ) as f :
194194 json .dump (json_data , f , ensure_ascii = False , indent = 4 )
195195 except Exception as e :
196- print (json_data )
197196 output_file = "citation_match_llm_res.json"
198197 with open (output_file , 'w' , encoding = 'utf-8' ) as f :
199- loguru .logger .info (json_data )
198+ # loguru.logger.info(json_data)
200199 json .dump (json_data , f , ensure_ascii = False , indent = 4 )
201200 loguru .logger .info (f"Parameters saved to { output_file } " )
202- citation_result = self .extract_citations (response = response )
203- parsed_result = citation_result ["parsed_result" ]
204- print (citation_result )
201+ citation_result = self .extract_citations (response = response )
202+ parsed_result = citation_result ["parsed_result" ]
205203
206204 quote_list = []
207-
208- for idx ,citation_item in enumerate (parsed_result ):
209- #todo:判断citation_item的类型,是text还是citation,
205+ existed_citations = []
206+ citation_indices_map = {}
207+ start_indices = 0
208+ for idx , citation_item in enumerate (parsed_result ):
209+ # todo:判断citation_item的类型,是text还是citation,
210210 # 如果当前citation_item是text,判断下一个类型是否为是citation,如果为citation,那么best_idx等于下面:
211211 # best_idx=parsed_result[idx+1]["index"]
212- if idx <= len (parsed_result )- 2 :
212+ if idx <= len (parsed_result ) - 2 :
213213 if citation_item ["type" ] == "text" :
214- if parsed_result [idx + 1 ]["type" ] == "citation" :
215-
216- best_idx = parsed_result [idx + 1 ]["index" ]# 这个是selected_idx的真实引号+1,例如38
217- best_idx = selected_idx .index ((int (best_idx )- 1 ))#
218-
219- print (best_idx )
220- response_content = citation_item ["content" ]
221- select_content = selected_docs [best_idx ]["content" ]
222-
214+ if parsed_result [idx + 1 ]["type" ] == "citation" :
215+ raw_idx = parsed_result [idx + 1 ]["index" ] # 这个是selected_idx的真实引号+1,例如38
216+ best_idx = selected_idx .index ((int (raw_idx ) - 1 )) #
217+ # loguru.logger.info(f"raw_idx:{raw_idx},best_idx:{best_idx}")
218+ response_content = citation_item ["content" ]
219+ select_content = selected_docs [best_idx ]["content" ]
223220 highlighted_start_end = self .highlight_common_substrings (response_content , select_content )
224221 group_item = {
225222 "doc_id" : selected_docs [best_idx ]["doc_id" ],
@@ -229,30 +226,39 @@ def ground_response(
229226 "doc_title" : selected_docs [best_idx ]["newsinfo" ]["title" ],
230227 # "chk_content": selected_docs[best_idx]['content'],
231228 "chk_content" : select_content ,
232- "best_ratio" : self .cal_common_ration (response_content ,select_content ),
229+ "best_ratio" : self .cal_common_ration (response_content , select_content ),
233230 "highlight" : highlighted_start_end ,
234231 }
235-
236232 group_data = {
237233 "doc_list" : [group_item ],
238234 "chk_content" : group_item ["chk_content" ],
239235 "highlight" : group_item ["highlight" ],
240236 }
241- quote_list .append (group_data )
242-
243-
244- response_result = '' .join ([item ["content" ] for item in citation_result ["parsed_result" ]])
237+ if start_indices not in citation_result ["citations" ] and group_data ["chk_content" ] not in existed_citations :
238+ quote_list .append (group_data )
239+ existed_citations .append (group_data ["chk_content" ])
240+ citation_indices_map [raw_idx ] = start_indices
241+ start_indices += 1
242+
243+ loguru .logger .info (citation_indices_map )
244+ loguru .logger .info (len (quote_list ))
245+ final_responses = []
246+ for item in citation_result ["parsed_result" ]:
247+ if item ["type" ] == "text" :
248+ final_responses .append (item ["content" ])
249+ else :
250+ citation_ind = citation_indices_map [item ["index" ]]+ 1
251+ final_responses .append (f"[{ citation_ind } ]" )
252+ response_result = '' .join (final_responses )
245253 data = {'result' : response_result , 'quote_list' : quote_list , 'summary' : '' }
246-
247254 # Save to JSON file
248255 json_data ['result' ] = response_result
249256 json_data ['quote_list' ] = quote_list
250- output_file = "citation_match_llm_res.json"
257+ # loguru.logger.info(response_result)
258+
251259 with open (output_file , 'w' , encoding = 'utf-8' ) as f :
252260 json .dump (json_data , f , ensure_ascii = False , indent = 4 )
253261 loguru .logger .info (f"Parameters saved to { output_file } " )
254-
255-
256262 return data
257263
258264
0 commit comments