Skip to content

Commit c16d5c8

Browse files
Merge pull request #137 from jerryao/feature/deepseek-r1-integration
Feature/deepseek r1 integration
2 parents 114b59f + 488d8ce commit c16d5c8

File tree

9 files changed

+439
-39
lines changed

9 files changed

+439
-39
lines changed

.gitignore

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,91 @@ trustrag/modules/deepresearch/.env
3535
*.env
3636
examples/deep-research
3737
examples/deep-research/local-deep-research
38-
trustrag.egg-info
38+
trustrag.egg-info
39+
40+
# Byte-compiled / optimized / DLL files
41+
__pycache__/
42+
*.py[cod]
43+
*$py.class
44+
45+
# C extensions
46+
*.so
47+
48+
# Distribution / packaging
49+
.Python
50+
build/
51+
develop-eggs/
52+
dist/
53+
downloads/
54+
eggs/
55+
.eggs/
56+
lib/
57+
lib64/
58+
parts/
59+
sdist/
60+
var/
61+
wheels/
62+
*.egg-info/
63+
.installed.cfg
64+
*.egg
65+
66+
# PyInstaller
67+
# Usually these files are written in a setup.py script generated for the project
68+
*.manifest
69+
*.spec
70+
71+
# Installer logs
72+
pip-log.txt
73+
pip-delete-this-directory.txt
74+
75+
# Unit test / coverage reports
76+
htmlcov/
77+
.tox/
78+
.nox/
79+
.coverage
80+
.coverage.*
81+
.cache
82+
nosetests.xml
83+
coverage.xml
84+
*.cover
85+
.hypothesis/
86+
.pytest_cache/
87+
88+
# Jupyter Notebook
89+
.ipynb_checkpoints
90+
91+
# Environment variables and keys
92+
.env
93+
.env.*
94+
!.env.example
95+
96+
# API Keys
97+
**/config_local*.json
98+
*apikey*
99+
*api_key*
100+
101+
# Data and large files
102+
data/
103+
*.zip
104+
*.gz
105+
*.tar
106+
*.rar
107+
output.md
108+
109+
# IDE files
110+
.idea/
111+
.vscode/
112+
*.swp
113+
*.swo
114+
115+
# Logs
116+
logs/
117+
*.log
118+
119+
# Mac specific
120+
.DS_Store
121+
122+
# Documentation build
123+
_build/
124+
_static/
125+
_templates/

DEEPSEEK-R1-README.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# DeepSeek-R1 模型集成
2+
3+
本分支添加了对 SiliconFlow API 的支持,使 TrustRAG 框架能够使用 DeepSeek-R1 模型进行检索增强生成和深度研究。
4+
5+
## 主要特性
6+
7+
1. **SiliconFlow API 集成**
8+
- 添加了 SiliconFlow API 端点配置
9+
- 支持 DeepSeek-R1 等一系列高性能模型
10+
11+
2. **DeepResearch 模块增强**
12+
- 改进了响应解析机制,支持 reasoning_content 字段
13+
- 优化了异常处理,提高了系统稳定性
14+
- 添加了结构化数据转换,处理不同格式的响应
15+
16+
3. **Web 应用支持**
17+
- 在应用界面添加了 DeepSeek-R1 模型选项
18+
- 实现了根据选择的模型动态切换 API 服务
19+
- 维护了统一的用户体验
20+
21+
## 如何使用
22+
23+
### 配置 SiliconFlow API
24+
25+
在 `.env` 文件(或 `config_online.json`)中添加以下配置:
26+
27+
```bash
28+
# SiliconFlow (DeepSeek-R1)
29+
SILICONFLOW_API_KEY="your_api_key_here"
30+
SILICONFLOW_MODEL="deepseek-ai/DeepSeek-R1"
31+
SILICONFLOW_ENDPOINT="https://api.siliconflow.cn/v1"
32+
```
33+
34+
### 运行 DeepResearch
35+
36+
```bash
37+
cd trustrag/modules/deepresearch
38+
python pipeline.py
39+
```
40+
41+
在提示时选择研究主题,系统将使用 DeepSeek-R1 模型进行深度研究并生成详细报告。
42+
43+
### 使用 Web 界面
44+
45+
```bash
46+
python app.py
47+
```
48+
49+
在 Web 界面中选择 "DeepSeek-R1" 模型进行问答。
50+
51+
## 支持的模型
52+
53+
SiliconFlow API 支持多种强大的模型,包括:
54+
55+
- deepseek-ai/DeepSeek-R1 (默认)
56+
- deepseek-ai/DeepSeek-V3
57+
- Qwen/QwQ-32B
58+
- 更多模型请参考 SiliconFlow 文档
59+
60+
## 技术详情
61+
62+
本集成通过 OpenAI 兼容 API 接口调用 SiliconFlow 服务,并对 DeepSeek-R1 模型的特殊响应格式(如 reasoning_content 字段)进行了专门处理,确保了系统能够充分利用模型的推理能力。

app.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,23 @@ def predict(question,
295295
loguru.logger.info("User Question:" + question)
296296
if history is None:
297297
history = []
298+
299+
# 根据选择的模型设置API配置
300+
if large_language_model == "DeepSeek-R1":
301+
# 使用SiliconFlow API
302+
siliconflow_service = config.get_config('services.siliconflow')
303+
model_config = config.get_config('models.deepseek_r1')
304+
application.llm.base_url = siliconflow_service['base_url']
305+
application.llm.api_key = siliconflow_service['api_key']
306+
application.llm.model_name = model_config['name']
307+
else:
308+
# 使用默认DMX API
309+
dmx_service = config.get_config('services.dmx')
310+
model_config = config.get_config('models.llm')
311+
application.llm.base_url = dmx_service['base_url']
312+
application.llm.api_key = dmx_service['api_key']
313+
application.llm.model_name = model_config['name']
314+
298315
# Handle web content
299316
web_content = ''
300317
if use_web == 'Use':
@@ -493,6 +510,7 @@ def predict(question,
493510
large_language_model = gr.Dropdown(
494511
choices=[
495512
"GPT-4O-ALL",
513+
"DeepSeek-R1",
496514
],
497515
label="Large Language model",
498516
value="GPT-4O-ALL"

config_online.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
"api_key": "sk-xx",
66
"description": "DMX API 服务"
77
},
8+
"siliconflow": {
9+
"base_url": "https://api.siliconflow.cn/v1",
10+
"api_key": "sk-yfgjndsavpwcnnedlhllyfunxwsckfguirokexokstbvwnjf",
11+
"description": "SiliconFlow API 服务"
12+
},
813
"rerank": {
914
"base_url": "http://localhost:3600",
1015
"api_key": "sk-xxx",
@@ -17,6 +22,11 @@
1722
"service": "dmx",
1823
"description": "主要的 LLM 模型"
1924
},
25+
"deepseek_r1": {
26+
"name": "deepseek-ai/DeepSeek-R1",
27+
"service": "siliconflow",
28+
"description": "DeepSeek-R1 模型"
29+
},
2030
"embedding": {
2131
"name": "text-embedding-3-large",
2232
"service": "dmx",

trustrag/modules/citation/match_citation.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import jieba
55
import loguru
6+
import re
67

78
from trustrag.modules.document.utils import PROJECT_BASE
89

@@ -231,6 +232,61 @@ def ground_response(
231232
# print(json_data)
232233
return data
233234

235+
def find_citations(self, response: str = None):
236+
"""
237+
为兼容现有代码添加的方法,返回引用信息
238+
识别引用格式如 [数字] 的内容
239+
"""
240+
citation_pattern = r'\[(\d+)\]'
241+
citations = []
242+
243+
for match in re.finditer(citation_pattern, response):
244+
start, end = match.span()
245+
index = int(match.group(1))
246+
citations.append({
247+
"position": start,
248+
"citation": match.group(0),
249+
"index": index
250+
})
251+
252+
# 将内容解析为所需格式
253+
parsed_result = []
254+
last_position = 0
255+
256+
for citation in citations:
257+
# 添加引用前的文本
258+
if citation["position"] > last_position:
259+
text_content = response[last_position:citation["position"]]
260+
if text_content:
261+
parsed_result.append({
262+
"content": text_content,
263+
"type": "text"
264+
})
265+
266+
# 添加引用
267+
parsed_result.append({
268+
"content": citation["citation"],
269+
"type": "citation",
270+
"index": citation["index"]
271+
})
272+
273+
last_position = citation["position"] + len(citation["citation"])
274+
275+
# 添加最后一个引用后的剩余文本
276+
if last_position < len(response):
277+
parsed_result.append({
278+
"content": response[last_position:],
279+
"type": "text"
280+
})
281+
282+
return {
283+
"citations": citations,
284+
"parsed_result": parsed_result
285+
}
286+
287+
# 添加extract_citations作为find_citations的别名,以兼容app_fixed.py中的调用
288+
extract_citations = find_citations
289+
234290

235291
if __name__ == '__main__':
236292
mc = MatchCitation()

trustrag/modules/deepresearch/action.py

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -125,48 +125,88 @@ async def process_serp_result(
125125

126126

127127
async def write_final_report(
128-
prompt: str,
129-
learnings: List[str],
130-
visited_urls: List[str],
131-
client: openai.OpenAI,
132-
model: str,
133-
) -> str:
134-
"""Generate final report based on all research learnings."""
135-
136-
learnings_string = trim_prompt(
137-
"\n".join([f"<learning>\n{learning}\n</learning>" for learning in learnings]),
138-
# 150_000,
139-
300_000,
140-
)
128+
prompt,
129+
learnings,
130+
visited_urls,
131+
client,
132+
model,
133+
):
134+
learnings_string = ""
135+
for i, learning in enumerate(learnings, 1):
136+
learnings_string += f"{i}. {learning}\n"
141137

142138
user_prompt = (
143-
f"根据以下用户提供的提示,使用研究中获得的学习要点撰写关于该主题的最终报告。返回一个JSON对象,"
144-
f"其中包含'reportMarkdown'字段,该字段包含详细的markdown报告(目标为3页以上),尽量内容丰富饱满。包括研究中的所有学习要点:\n\n"
145-
f"<prompt>{prompt}</prompt>\n\n"
146-
f"以下是研究中获得的所有学习要点:\n\n<learnings>\n{learnings_string}\n</learnings>"
139+
"根据以下用户提供的提示,使用研究中获得的学习要点撰写关于该主题的最终报告。"
140+
"返回一个JSON对象,其中包含'reportMarkdown'字段,该字段包含详细的markdown格式报告,至少3页。"
141+
f"\n\n提示: {prompt}\n\n学习要点:\n{learnings_string}"
147142
)
143+
148144
response = await get_client_response(
149145
client=client,
150146
model=model,
151147
messages=[
152-
{"role": "system", "content": DEEPSEARCH_SYSTEM_PROMPT},
148+
{
149+
"role": "system",
150+
"content": "你是一位专业的研究报告撰写者。你擅长将一组研究发现整合成结构化、详尽的研究报告。",
151+
},
153152
{"role": "user", "content": user_prompt},
154153
],
155154
response_format={"type": "json_object"},
156155
)
157156

158157
try:
159-
report = response.get("reportMarkdown", "")
158+
# 检查response是否为字典或列表
159+
if isinstance(response, dict):
160+
report = response.get("reportMarkdown", "")
161+
elif isinstance(response, list):
162+
# 如果是列表,尝试从中提取报告内容
163+
report = ""
164+
for item in response:
165+
if isinstance(item, dict) and "reportMarkdown" in item:
166+
report = item["reportMarkdown"]
167+
break
168+
169+
# 如果没有找到reportMarkdown,尝试构建一个简单的报告
170+
if not report:
171+
report = "# RAG研究报告\n\n"
172+
report += "## 主题介绍\n\n检索增强生成(Retrieval-Augmented Generation,RAG)是一种将检索系统与生成式AI模型结合的技术框架。\n\n"
173+
report += "## 研究发现\n\n"
174+
175+
# 添加从响应中获取的任何有用信息
176+
for item in response:
177+
if isinstance(item, dict):
178+
for key, value in item.items():
179+
if isinstance(value, str) and len(value) > 100: # 假设长文本内容可能有用
180+
report += f"### {key}\n\n{value}\n\n"
181+
else:
182+
# 备用报告
183+
report = "# RAG研究报告\n\n无法从API响应生成报告。请检查API连接。"
160184

161185
# Append sources
162186
urls_section = "\n\n## 来源\n\n" + "\n".join(
163-
[f"- {url}" for url in visited_urls]
187+
[f"- [{url}]({url})" for url in visited_urls]
164188
)
165-
return report + urls_section
166-
except json.JSONDecodeError as e:
167-
print(f"Error parsing JSON response: {e}")
168-
print(f"Raw response: {response}")
169-
return "Error generating report"
189+
190+
report = report + urls_section if visited_urls else report
191+
192+
# Save to file
193+
with open("output.md", "w", encoding="utf-8") as f:
194+
f.write(report)
195+
196+
return report
197+
except Exception as e:
198+
error_report = f"# 报告生成错误\n\n生成最终报告时出错: {str(e)}\n\n"
199+
error_report += f"## 原始查询\n\n{prompt}\n\n"
200+
error_report += f"## 收集的信息\n\n{learnings_string}\n\n"
201+
202+
# 添加调试信息
203+
error_report += f"## 调试信息\n\n```\n响应类型: {type(response)}\n响应内容: {response}\n```\n"
204+
205+
# Save to file
206+
with open("output.md", "w", encoding="utf-8") as f:
207+
f.write(error_report)
208+
209+
return error_report
170210

171211

172212
async def deep_research(

0 commit comments

Comments
 (0)