
Commit c9e1273

Merge pull request #126 from gomate-community/pipeline
Pipeline
2 parents 2283ddb + d963b29 commit c9e1273

11 files changed: +337 −14 lines


.gitignore

Lines changed: 4 additions & 1 deletion
@@ -25,4 +25,7 @@ examples/rag/indexs
 examples/rag/mobile_rag.py
 **/.ipynb_checkpoints/
 .virtual_documents/
-examples/retrievers/dense_cache
+examples/retrievers/dense_cache
+examples/datasets/papers
+examples/download/models
+.gradio

README_zh.md

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ The DeepResearch framework, through steps such as hierarchical querying, recursive iteration, and intelligent decision-making
 The system decides whether to continue execution based on the following conditions:
 1. **Whether the token budget has been exceeded**
 2. **Whether the action depth has been exceeded**
->If either of the above conditions is met, the query is terminated and the answer is returned directly; otherwise the process enters the recursive execution step.
+>If either of the above conditions is met, the query is terminated and the answer is returned directly; otherwise the process enters the recursive execution step.
 
 3. Recursive execution step
 During recursive execution, the system performs tasks such as information retrieval, model inference, and context processing
@@ -44,7 +44,7 @@ The DeepResearch framework, through steps such as hierarchical querying, recursive iteration, and intelligent decision-making
 - **Recursive traversal**
 - **Depth-first search**
 - **Model inference**
->The system performs model inference, using the system prompt and its understanding of the context to decide the next action.
+>The system performs model inference, using the system prompt and its understanding of the context to decide the next action.
 4. Action type determination
 Based on the inference result, the system decides which type of action to execute next:
 - **answer**: answering action
@@ -53,7 +53,7 @@ The DeepResearch framework, through steps such as hierarchical querying, recursive iteration, and intelligent decision-making
 - **read**: reading action
 - **coding**: coding action
 
->These actions affect the context and continually update the system state.
+>These actions affect the context and continually update the system state.
 
 5. Result feedback
 Based on the final action type, the system executes the corresponding task and returns the result to the user, completing the whole workflow.
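
The README excerpt above describes a decision loop: check the token budget and action depth, otherwise run retrieval/inference recursively and dispatch one of the answer/search/read/coding actions. Below is a minimal, self-contained sketch of that loop; it is not TrustRAG's actual implementation, and every function and parameter name in it is a hypothetical stand-in.

```python
from typing import List, Literal, Optional, Tuple

Action = Literal["answer", "search", "read", "coding"]

# --- Hypothetical stand-ins for TrustRAG's retriever / LLM calls ---
def infer_next_action(query: str, context: List[str]) -> Tuple[Action, str]:
    # A real system prompts the model with the system prompt plus the current context.
    return ("answer", "") if context else ("search", query)

def run_action(action: Action, payload: str) -> str:
    # search / read / coding would call a retriever, a reader, or a code executor.
    return f"[{action} result for: {payload}]"

def generate_answer(query: str, context: List[str]) -> str:
    return f"Answer to {query!r} from {len(context)} context chunk(s)"

def estimate_tokens(context: List[str]) -> int:
    return sum(len(chunk) // 4 for chunk in context)

def deep_research(query: str, context: Optional[List[str]] = None, depth: int = 0,
                  token_budget: int = 8000, max_depth: int = 5) -> str:
    context = context or []
    # Steps 1-2: stop if the token budget or action depth is exceeded -> answer directly.
    if estimate_tokens(context) > token_budget or depth > max_depth:
        return generate_answer(query, context)
    # Step 3: recursive execution -- model inference chooses the next action.
    action, payload = infer_next_action(query, context)
    # Step 4: action dispatch -- "answer" ends the recursion, the rest update the context.
    if action == "answer":
        return generate_answer(query, context)
    context = context + [run_action(action, payload)]
    # Depth-first recursion over the updated state; step 5 returns the final result.
    return deep_research(query, context, depth + 1, token_budget, max_depth)

print(deep_research("What is retrieval-augmented generation?"))
```

The point the README emphasizes is that every non-answer action feeds its result back into the context before the next, depth-first recursive call.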

app.py

Lines changed: 6 additions & 4 deletions
@@ -17,7 +17,6 @@
 import pandas as pd
 
 from trustrag.applications.rag_openai import RagApplication, ApplicationConfig
-from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig
 from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig
 from datetime import datetime
 import pytz
@@ -533,7 +532,7 @@ def predict(question,
             )
         with gr.Column(scale=4):
             with gr.Row():
-                chatbot = gr.Chatbot(label='TrustRAG Application').style(height=650)
+                chatbot = gr.Chatbot(label='TrustRAG Application', height=650)
             with gr.Row():
                 message = gr.Textbox(label='Please enter a question')
             with gr.Row():
@@ -583,13 +582,16 @@ def predict(question,
                 state
             ],
             outputs=[message, chatbot, state, search, rewrite] + checkbox_outputs)
+    with gr.Tab("\N{book} DeepRsearch"):
+        with gr.Row():
+            gr.Markdown(
+                """>Remind:[TrustRAG Application](https://github.com/gomate-community/TrustRAG/issues)If you have any questions, please provide feedback in [Github Issue区](https://github.com/gomate-community/TrustRAG/issues) .""")
 
-demo.queue(concurrency_count=2).launch(
+demo.queue(max_size=2).launch(
     server_name='0.0.0.0',
     server_port=7860,
     share=True,
     show_error=True,
     debug=True,
-    enable_queue=True,
     inbrowser=False,
 )
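
The app.py changes above track the Gradio 4.x API: `.style(height=...)` was removed in favor of constructor arguments, `launch(enable_queue=...)` is gone, and `queue()` no longer accepts `concurrency_count`. A minimal sketch of the corresponding 4.x idioms follows (assuming gradio >= 4.0; the echo handler is illustrative, not TrustRAG's predict function):

```python
import gradio as gr

def echo(message, history):
    # Illustrative handler only; TrustRAG wires its own predict() here.
    history = history + [(message, f"echo: {message}")]
    return "", history

with gr.Blocks() as demo:
    # Gradio 4.x: .style(height=...) is gone; height is a constructor argument.
    chatbot = gr.Chatbot(label="TrustRAG Application", height=650)
    message = gr.Textbox(label="Please enter a question")
    message.submit(echo, inputs=[message, chatbot], outputs=[message, chatbot])

# Gradio 4.x: launch(enable_queue=...) and queue(concurrency_count=...) were removed.
# max_size caps pending requests; worker concurrency is set via default_concurrency_limit.
demo.queue(max_size=2, default_concurrency_limit=2).launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
    inbrowser=False,
)
```

Note that `max_size` bounds how many requests may wait in the queue, which is a different knob from the old `concurrency_count`; per-worker concurrency is configured separately in current Gradio releases.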

config.json

Lines changed: 2 additions & 2 deletions
@@ -2,12 +2,12 @@
   "services": {
     "dmx": {
       "base_url": "https://www.dmxapi.com/v1",
-      "api_key": "sk-gDbFoQAYz9pwqBsH0aPA1H8DN9s0B9F3vWNjjPcijRBFjk7f",
+      "api_key": "sk-xx",
       "description": "DMX API service"
     },
     "rerank": {
       "base_url": "http://localhost:3600",
-      "api_key": "sk-XTcBLdakFcZjdQTt7e29Ca9bF8F1495dB447E3Af023cF4E6",
+      "api_key": "sk-xxx",
       "description": "Reranking service"
     }
   },
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
import arxiv
import os
import json
import time
from tqdm import tqdm
import logging
from datetime import datetime

def process_metadata(result):
    """
    Convert an arXiv result object into a structured dictionary

    Args:
        result: arxiv.Result object

    Returns:
        dict: structured metadata dictionary
    """
    metadata = {
        "entry_id": result.entry_id,
        "updated": str(result.updated),
        "published": str(result.published),
        "title": result.title,
        "authors": [author.name for author in result.authors],
        "summary": result.summary,
        "comment": str(result.comment),
        "journal_ref": str(result.journal_ref),
        "doi": str(result.doi),
        "primary_category": result.primary_category,
        "categories": result.categories,
        "links": [{"title": link.title, "href": link.href, "rel": link.rel} for link in result.links],
        "pdf_url": result.pdf_url,
        "download_time": datetime.now().isoformat()
    }

    return metadata


def download_arxiv_papers(topic, max_papers=200, save_dir="papers", sleep_interval=2):
    """
    Download arXiv papers for a given topic and save structured metadata

    Args:
        topic (str): topic/query to search for
        max_papers (int): maximum number of papers to download
        save_dir (str): base directory for saving papers
        sleep_interval (float): delay between downloads (to avoid API rate limits)

    Returns:
        int: number of papers downloaded successfully
    """
    # Configure logging
    topic_safe = topic.replace(' ', '_').replace('/', '_').replace('\\', '_')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(f"{save_dir}/{topic_safe}_download.log"),
            logging.StreamHandler()
        ]
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Starting download for topic: {topic}")

    # Create the folder structure
    topic_dir = os.path.join(save_dir, f"topic_{topic_safe}")
    pdfs_dir = os.path.join(topic_dir, "pdfs")
    metadata_dir = os.path.join(topic_dir, "metadata")

    os.makedirs(pdfs_dir, exist_ok=True)
    os.makedirs(metadata_dir, exist_ok=True)

    logger.info(f"Created directories: {pdfs_dir}, {metadata_dir}")

    # Aggregate metadata file holding info for every downloaded paper
    all_metadata_file = os.path.join(topic_dir, f"{topic_safe}_all_metadata.json")
    all_metadata = []

    # Configure the search
    search = arxiv.Search(
        query=topic,
        max_results=max_papers,
        sort_by=arxiv.SortCriterion.Relevance
    )

    client = arxiv.Client()

    # Initialize counters
    successful_downloads = 0
    failed_downloads = 0

    # Download papers
    try:
        results = list(client.results(search))
        total_results = len(results)
        logger.info(f"Found {total_results} papers for topic '{topic}'")

        for i, result in enumerate(tqdm(results, desc=f"Downloading papers for topic '{topic}'")):
            try:
                # Get the paper ID and build file names
                paper_id = result.get_short_id()
                pdf_filename = f"{paper_id}.pdf"
                metadata_filename = f"{paper_id}.json"

                # Process metadata
                metadata = process_metadata(result)
                metadata_path = os.path.join(metadata_dir, metadata_filename)

                # Save per-paper metadata
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(metadata, f, ensure_ascii=False, indent=2)

                # Append to the aggregate metadata
                all_metadata.append(metadata)

                # Save the aggregate metadata, updating every 10 papers
                if (i + 1) % 10 == 0 or (i + 1) == total_results:
                    with open(all_metadata_file, 'w', encoding='utf-8') as f:
                        json.dump(all_metadata, f, ensure_ascii=False, indent=2)

                # Download the PDF
                pdf_path = os.path.join(pdfs_dir, pdf_filename)
                result.download_pdf(dirpath=pdfs_dir, filename=pdf_filename)
                successful_downloads += 1

                # Sleep to avoid rate limiting
                time.sleep(sleep_interval)

            except Exception as e:
                logger.error(f"Error downloading paper {paper_id}: {str(e)}")
                failed_downloads += 1

            # Log progress every 10 papers
            if (i + 1) % 10 == 0:
                logger.info(f"Progress: {i + 1}/{total_results} papers processed")
                time.sleep(0.5)
    except Exception as e:
        logger.error(f"Error during search or download: {str(e)}")

    # Log final statistics
    logger.info(f"Download for topic '{topic}' finished")
    logger.info(f"Successfully downloaded: {successful_downloads} papers")
    logger.info(f"Failed downloads: {failed_downloads} papers")

    return successful_downloads


def batch_download_topics(topics_list, max_papers_per_topic=200, base_dir="papers"):
    """
    Batch-download papers for multiple topics

    Args:
        topics_list (list): list of topics
        max_papers_per_topic (int): maximum number of papers per topic
        base_dir (str): base save directory

    Returns:
        dict: download statistics for each topic
    """
    os.makedirs(base_dir, exist_ok=True)

    results = {}
    total_start_time = time.time()

    for i, topic in enumerate(topics_list):
        print(f"\n[{i + 1}/{len(topics_list)}] Starting download for topic: {topic}")

        topic_start_time = time.time()
        papers_downloaded = download_arxiv_papers(
            topic=topic,
            max_papers=max_papers_per_topic,
            save_dir=base_dir,
            sleep_interval=3  # slightly longer delay for batch downloads
        )

        topic_elapsed_time = time.time() - topic_start_time

        results[topic] = {
            "papers_downloaded": papers_downloaded,
            "elapsed_time": f"{topic_elapsed_time:.2f} s"
        }

        print(f"Topic '{topic}' finished: downloaded {papers_downloaded} papers in {topic_elapsed_time:.2f} s")

        # Extra sleep between topics to reduce API load
        if i < len(topics_list) - 1:
            rest_time = 10
            print(f"Resting {rest_time} s before the next topic...")
            time.sleep(rest_time)

    total_elapsed_time = time.time() - total_start_time
    print(f"\nBatch download finished! Total time: {total_elapsed_time:.2f} s")

    # Save a summary of the batch download
    summary_file = os.path.join(base_dir, "batch_download_summary.json")
    with open(summary_file, 'w', encoding='utf-8') as f:
        summary = {
            "total_topics": len(topics_list),
            "total_time": f"{total_elapsed_time:.2f} s",
            "completed_at": datetime.now().isoformat(),
            "topics_results": results
        }
        json.dump(summary, f, ensure_ascii=False, indent=2)

    return results


# Usage examples:
if __name__ == "__main__":
    # Download a single topic
    # download_arxiv_papers("Reasoning Large Language Models", max_papers=200)

    # Batch-download multiple topics
    topics = [
        "Reasoning Large Language Models",
        # "LLM Post-Training",
        # "Chain of Thought",
    ]

    batch_download_topics(topics, max_papers_per_topic=200)

examples/datasets/parse_papers.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import os
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
from tqdm import tqdm

# Directories to process
directories = [
    "papers/topic_Chain_of_Thought/pdfs",
    "papers/topic_LLM_Post-Training/pdfs",
    "papers/topic_Reasoning_Large_Language_Models/pdfs",
]

def process_pdf(pdf_file_path, output_dir):
    pdf_file_name = os.path.basename(pdf_file_path)  # get the PDF file name
    name_without_suff = pdf_file_name.split(".")[0]  # strip the file extension

    # Prepare the environment
    local_image_dir = os.path.join(output_dir, "images")  # image output directory
    local_md_dir = output_dir  # Markdown output directory
    image_dir = str(os.path.basename(local_image_dir))  # image directory name

    os.makedirs(local_image_dir, exist_ok=True)  # create the image output directory

    # Create writer objects
    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

    # Read the PDF file bytes
    reader1 = FileBasedDataReader("")
    pdf_bytes = reader1.read(pdf_file_path)  # read the PDF file content

    # Process the PDF file
    # Create a dataset instance
    ds = PymuDocDataset(pdf_bytes)

    # Infer the PDF type and process it accordingly
    if ds.classify() == SupportedPdfParseMethod.OCR:
        infer_result = ds.apply(doc_analyze, ocr=True)  # parse with OCR
        pipe_result = infer_result.pipe_ocr_mode(image_writer)  # process OCR-mode results
    else:
        infer_result = ds.apply(doc_analyze, ocr=False)  # parse in text mode
        pipe_result = infer_result.pipe_txt_mode(image_writer)  # process text-mode results

    # Draw results and extract content
    infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))  # draw model results
    model_inference_result = infer_result.get_infer_res()  # get model inference results
    pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))  # draw layout results
    pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))  # draw span results
    md_content = pipe_result.get_markdown(image_dir)  # get Markdown content
    pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)  # dump the Markdown file
    content_list_content = pipe_result.get_content_list(image_dir)  # get the content list
    pipe_result.dump_content_list(md_writer, f"{name_without_suff}_content_list.json", image_dir)  # dump the content list JSON
    middle_json_content = pipe_result.get_middle_json()  # get the intermediate JSON content
    pipe_result.dump_middle_json(md_writer, f'{name_without_suff}_middle.json')  # dump the intermediate JSON file

# Process each directory
for directory in directories:
    output_dir = os.path.join(directory, "output")  # output directory
    os.makedirs(output_dir, exist_ok=True)  # create the output directory

    for file_name in tqdm(os.listdir(directory)):
        if file_name.endswith(".pdf"):  # check whether the file is a PDF
            pdf_file_path = os.path.join(directory, file_name)  # get the PDF file path
            process_pdf(pdf_file_path, output_dir)  # process the PDF file

0 commit comments
