
Commit fd60d31

Merge pull request #130 from gomate-community/pipeline
Pipeline
2 parents d1ad740 + b251b17 commit fd60d31

File tree

7 files changed: +476 -322 lines changed


.gitignore

Lines changed: 3 additions & 1 deletion
@@ -30,4 +30,6 @@ examples/datasets/papers
 examples/download/models
 .gradio
 examples/datasets/arxiv/papers
-examples/projects/arxiv/papers
+examples/projects/arxiv/papers
+trustrag/modules/deepsearch/.env
+*.env

docs/mysql.md

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+## MySQL deployment
+
+```bash
+docker stop mysql
+
+docker rm -f mysql
+
+docker run --name mysql \
+-p 3306:3306 \
+--restart always \
+-v G:/Ubuntu_WSL/rag-middlewares/mysql/data:/var/lib/mysql \
+-e MYSQL_ROOT_PASSWORD=123456 \
+-d mysql:latest
+
+```
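Not part of the diff above: a minimal Python sketch for checking that the container is reachable, assuming the pymysql package is installed and the root/123456 credentials from the docker run command above.

```python
# Hypothetical connectivity check for the MySQL container started above.
# Assumes `pip install pymysql` and the root/123456 credentials from the docker run command.
import pymysql

conn = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="123456")
try:
    with conn.cursor() as cursor:
        cursor.execute("SELECT VERSION()")
        print(cursor.fetchone())  # e.g. ('8.x.x',)
finally:
    conn.close()
```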

trustrag/config/config_loader.py

Lines changed: 3 additions & 3 deletions
@@ -10,13 +10,13 @@ class ConfigLoader:
 
     _instance = None
     _config = None
-
-    def __new__(cls):
+
+    def __new__(cls, *args, **kwargs):
         """Singleton pattern: ensure only one configuration instance exists"""
         if cls._instance is None:
             cls._instance = super(ConfigLoader, cls).__new__(cls)
         return cls._instance
-
+
     def __init__(self,config_path):
         """Initialize the configuration loader"""
         self.config_path=config_path
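For context (not from the commit): Python passes the constructor arguments to both __new__ and __init__, so the old zero-argument __new__ would raise a TypeError as soon as ConfigLoader is instantiated with a config_path. A minimal illustrative sketch of the pattern:

```python
# Minimal sketch of the singleton pattern used by ConfigLoader (illustrative, not the project's code).
class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # __new__ receives the same arguments as __init__, so it must accept
        # (and ignore) them; otherwise Singleton("config.yaml") raises TypeError.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config_path):
        self.config_path = config_path


a = Singleton("config.yaml")
b = Singleton("other.yaml")
print(a is b)  # True: the second call reuses (and re-initializes) the same instance
```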

trustrag/modules/citation/citation_match_llm.json

Lines changed: 260 additions & 89 deletions
Large diffs are not rendered by default.

trustrag/modules/citation/citation_match_llm_res.json

Lines changed: 150 additions & 188 deletions
Large diffs are not rendered by default.

trustrag/modules/citation/llm_citation.py

Lines changed: 41 additions & 35 deletions
@@ -1,6 +1,7 @@
 import json
+import re
 from typing import List
-import re
+
 import jieba
 import loguru
 
@@ -15,7 +16,7 @@ def cut(self, para: str):
 
         # Define the list of end symbols
         # end_symbols = ['。', '!', '?', '…', ';', '\n'] # sent
-        end_symbols = ['。', '!', '?', '…', ';', '\n']# para
+        end_symbols = ['。', '!', '?', '…', ';', '\n'] # para
 
         # Define quote pairs
         quote_pairs = {'"': '"', "'": "'", '「': '」', '『': '』'}
@@ -102,8 +103,7 @@ def highlight_common_substrings(self, response_content, select_content, min_leng
             best_match_positions = [[0, len(select_content) - 1]]
         return best_match_positions
 
-
-    def cal_common_ration(self,response,evidence):
+    def cal_common_ration(self, response, evidence):
         """
         Compute the similarity (overlap) between a paragraph in the answer and its matched evidence, based directly on the ratio of co-occurring words
         """
@@ -114,8 +114,7 @@ def cal_common_ration(self,response,evidence):
         ratio = len(overlap) / sentence_seg_cut_length
         return ratio
 
-
-    def extract_citations(self,response:str=None):
+    def extract_citations(self, response: str = None):
         """
         xxx[1]xxx[2],
         find all citation patterns like [number]
@@ -166,6 +165,7 @@ def extract_citations(self,response:str=None):
             "citations": citations,
             "parsed_result": parsed_result
         }
+
    def ground_response(
            self,
            question: str,
@@ -189,37 +189,34 @@ def ground_response(
 
         # Save to JSON file
         try:
-            output_file = "/home/yanqiang/code/citation_match_llm.json"
+            output_file = "/home/yanqiang/code/citation_match_llm_res.json"
             with open(output_file, 'w', encoding='utf-8') as f:
                 json.dump(json_data, f, ensure_ascii=False, indent=4)
         except Exception as e:
-            print(json_data)
             output_file = "citation_match_llm_res.json"
             with open(output_file, 'w', encoding='utf-8') as f:
-                loguru.logger.info(json_data)
+                # loguru.logger.info(json_data)
                 json.dump(json_data, f, ensure_ascii=False, indent=4)
         loguru.logger.info(f"Parameters saved to {output_file}")
-        citation_result=self.extract_citations(response=response)
-        parsed_result=citation_result["parsed_result"]
-        print(citation_result)
+        citation_result = self.extract_citations(response=response)
+        parsed_result = citation_result["parsed_result"]
 
         quote_list = []
-
-        for idx,citation_item in enumerate(parsed_result):
-            #todo: determine whether citation_item is of type text or citation,
+        existed_citations=[]
+        citation_indices_map = {}
+        start_indices = 0
+        for idx, citation_item in enumerate(parsed_result):
+            # todo: determine whether citation_item is of type text or citation,
            # if the current citation_item is text and the next item is a citation, then best_idx is:
            # best_idx=parsed_result[idx+1]["index"]
-            if idx<=len(parsed_result)-2:
+            if idx <= len(parsed_result) - 2:
                 if citation_item["type"] == "text":
-                    if parsed_result[idx+1]["type"] == "citation":
-
-                        best_idx=parsed_result[idx+1]["index"]# this is the real citation number in selected_idx plus 1, e.g. 38
-                        best_idx=selected_idx.index((int(best_idx)-1))#
-
-                        print(best_idx)
-                        response_content=citation_item["content"]
-                        select_content=selected_docs[best_idx]["content"]
-
+                    if parsed_result[idx + 1]["type"] == "citation":
+                        raw_idx = parsed_result[idx + 1]["index"]  # this is the real citation number in selected_idx plus 1, e.g. 38
+                        best_idx = selected_idx.index((int(raw_idx) - 1)) #
+                        # loguru.logger.info(f"raw_idx:{raw_idx},best_idx:{best_idx}")
+                        response_content = citation_item["content"]
+                        select_content = selected_docs[best_idx]["content"]
                         highlighted_start_end = self.highlight_common_substrings(response_content, select_content)
                         group_item = {
                             "doc_id": selected_docs[best_idx]["doc_id"],
@@ -229,30 +226,39 @@ def ground_response(
                             "doc_title": selected_docs[best_idx]["newsinfo"]["title"],
                             # "chk_content": selected_docs[best_idx]['content'],
                             "chk_content": select_content,
-                            "best_ratio": self.cal_common_ration(response_content,select_content),
+                            "best_ratio": self.cal_common_ration(response_content, select_content),
                             "highlight": highlighted_start_end,
                         }
-
                         group_data = {
                             "doc_list": [group_item],
                             "chk_content": group_item["chk_content"],
                             "highlight": group_item["highlight"],
                         }
-                        quote_list.append(group_data)
-
-
-        response_result=''.join([item["content"] for item in citation_result["parsed_result"]])
+                        if start_indices not in citation_result["citations"] and group_data["chk_content"] not in existed_citations:
+                            quote_list.append(group_data)
+                            existed_citations.append(group_data["chk_content"])
+                            citation_indices_map[raw_idx] = start_indices
+                            start_indices += 1
+
+        loguru.logger.info(citation_indices_map)
+        loguru.logger.info(len(quote_list))
+        final_responses=[]
+        for item in citation_result["parsed_result"]:
+            if item["type"] == "text":
+                final_responses.append(item["content"])
+            else:
+                citation_ind=citation_indices_map[item["index"]]+1
+                final_responses.append(f"[{citation_ind}]")
+        response_result = ''.join(final_responses)
         data = {'result': response_result, 'quote_list': quote_list, 'summary': ''}
-
         # Save to JSON file
         json_data['result'] = response_result
         json_data['quote_list'] = quote_list
-        output_file = "citation_match_llm_res.json"
+        # loguru.logger.info(response_result)
+
         with open(output_file, 'w', encoding='utf-8') as f:
             json.dump(json_data, f, ensure_ascii=False, indent=4)
         loguru.logger.info(f"Parameters saved to {output_file}")
-
-
         return data
 
 
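For context (not from the commit), a simplified sketch of the renumbering idea this hunk introduces: raw [n] markers in the LLM response are mapped to consecutive display numbers in order of first appearance, and the answer text is rebuilt with the new markers. The commit's version additionally deduplicates by chunk content and checks citation_result["citations"]; the parsed_result sample below is illustrative only.

```python
# Illustrative sketch of the citation renumbering step added in ground_response:
# raw citation indices become consecutive display numbers in order of first appearance.
parsed_result = [
    {"type": "text", "content": "Solar output rose in 2024"},
    {"type": "citation", "index": "38"},
    {"type": "text", "content": " and storage costs fell"},
    {"type": "citation", "index": "12"},
]

citation_indices_map = {}  # raw index -> zero-based display slot
final_responses = []
for item in parsed_result:
    if item["type"] == "text":
        final_responses.append(item["content"])
    else:
        # first occurrence of a raw index gets the next free slot
        slot = citation_indices_map.setdefault(item["index"], len(citation_indices_map))
        final_responses.append(f"[{slot + 1}]")

print("".join(final_responses))
# Solar output rose in 2024[1] and storage costs fell[2]
```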

trustrag/modules/document/utils.py

Lines changed: 4 additions & 6 deletions
@@ -12,11 +12,9 @@
 import os
 import pathlib
 import re
-import chardet
-
-import tiktoken
+from typing import Union
 
-import pathlib
+import chardet
 
 # Get the path of the current file
 current_path = pathlib.Path(__file__).resolve()
@@ -34,7 +32,6 @@
 project_root_str = str(project_root)
 print(f"Project root directory: {project_root_str}")
 
-
 PROJECT_BASE = project_root_str
 all_codecs = [
     'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@@ -144,6 +141,8 @@ def findMaxTm(fnm):
     except Exception as e:
         pass
     return m
+
+
 def get_encoding(file: Union[str, bytes]) -> str:
     """
     Detects the encoding of a given file.
@@ -158,7 +157,6 @@ def get_encoding(file: Union[str, bytes]) -> str:
         tmp = chardet.detect(f.read())
         return tmp['encoding']
 
-
 # # https://stackoverflow.com/questions/76106366/how-to-use-tiktoken-in-offline-mode-computer
 # tiktoken_cache_dir = "/data/users/searchgpt/yq/GoMate/data/docs"
 # os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
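For context (not from the commit), a standalone sketch of the chardet-based detection that get_encoding wraps; the helper name below is illustrative, not the project's API.

```python
# Standalone sketch of chardet-based encoding detection
# (illustrative; the project's version lives in trustrag/modules/document/utils.py).
import chardet


def detect_encoding(path: str) -> str:
    """Return the most likely text encoding of the file at `path`."""
    with open(path, "rb") as f:
        return chardet.detect(f.read())["encoding"]


# e.g. read a file whose encoding is unknown ahead of time
# text = open("some_doc.txt", encoding=detect_encoding("some_doc.txt")).read()
```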
