
Commit e7eb8c8

Merge pull request #164 from gomate-community/pipeline
Pipeline
2 parents 926ed71 + 7d7f245 commit e7eb8c8

File tree

6 files changed (+294, -107 lines)

app_local_model.py

Lines changed: 82 additions & 28 deletions
@@ -10,60 +10,114 @@
 import sys
 
 sys.path.append(".")
-import os
 import shutil
 import time
 import gradio as gr
 import loguru
 import pandas as pd
 from datetime import datetime
 import pytz
-from trustrag.config.config_loader import ConfigLoader
-from trustrag.applications.rag import RagApplication, ApplicationConfig
 from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig
 from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig
+import os
+from trustrag.modules.citation.match_citation import MatchCitation
+from trustrag.modules.document.common_parser import CommonParser
+from trustrag.modules.generator.llm import Qwen3Chat
+from trustrag.modules.reranker.bge_reranker import BgeReranker
+from trustrag.modules.retrieval.dense_retriever import DenseRetriever
+from trustrag.modules.document.chunk import TextChunker
+from trustrag.modules.vector.embedding import SentenceTransformerEmbedding
+
+
+class ApplicationConfig():
+    def __init__(self):
+        self.retriever_config = None
+        self.rerank_config = None
+
+
+class RagApplication():
+    def __init__(self, config):
+        self.config = config
+        self.parser = CommonParser()
+        self.embedding_generator = SentenceTransformerEmbedding(self.config.retriever_config.model_name_or_path)
+        self.retriever = DenseRetriever(self.config.retriever_config, self.embedding_generator)
+        self.reranker = BgeReranker(self.config.rerank_config)
+        self.llm = Qwen3Chat(self.config.llm_model_path)
+        self.mc = MatchCitation()
+        self.tc = TextChunker()
+
+    def init_vector_store(self):
+        """
+
+        """
+        print("init_vector_store ... ")
+        all_paragraphs = []
+        all_chunks = []
+        for filename in os.listdir(self.config.docs_path):
+            file_path = os.path.join(self.config.docs_path, filename)
+            try:
+                paragraphs = self.parser.parse(file_path)
+                all_paragraphs.append(paragraphs)
+            except:
+                pass
+        print("chunking for paragraphs")
+        for paragraphs in all_paragraphs:
+            # Make sure paragraphs is a list and normalize its elements
+            if isinstance(paragraphs, list) and paragraphs:
+                if isinstance(paragraphs[0], dict):
+                    # list[dict] -> list[str]
+                    text_list = [' '.join(str(value) for value in item.values()) for item in paragraphs]
+                else:
+                    # already list[str]
+                    text_list = [str(item) for item in paragraphs]
+            else:
+                # handle any other case
+                text_list = [str(paragraphs)] if paragraphs else []
+
+            chunks = self.tc.get_chunks(text_list, 256)
+            all_chunks.extend(chunks)
+
+        self.retriever.build_from_texts(all_chunks)
+        print("init_vector_store done! ")
+        self.retriever.save_index(self.config.retriever_config.index_path)
+
+    def load_vector_store(self):
+        self.retriever.load_index(self.config.retriever_config.index_path)
+
+    def add_document(self, file_path):
+        chunks = self.parser.parse(file_path)
+        for chunk in chunks:
+            self.retriever.add_text(chunk)
+        print("add_document done!")
+
+    def chat(self, question: str = '', top_k: int = 5):
+        contents = self.retriever.retrieve(query=question, top_k=top_k)
+        contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
+        content = '\n'.join([content['text'] for content in contents])
+        result, history = self.llm.chat(question, [], content)
+        return result, history, contents, question
 
 
 # ========================== Config Start====================
-# # Create the global config instance
-# config = ConfigLoader(config_path="config_local.json")
-# app_config = ApplicationConfig()
-#
-# llm_model = config.get_config('models.llm')
-# embedding_model = config.get_config('models.embedding')
-# reranker_model = config.get_config('models.reranker')
-#
-# # Load the configuration
-# app_config.docs_path = config.get_config('paths.docs')
-# retriever_config = DenseRetrieverConfig(
-#     model_name_or_path=embedding_model["path"],
-#     dim=1024,
-#     index_path=config.get_config('index')
-# )
-# rerank_config = BgeRerankerConfig(
-#     model_name_or_path=reranker_model["path"],
-# )
 app_config = ApplicationConfig()
-app_config.docs_path = r"G:\Projects\TrustRAG\data\docs"
-app_config.llm_model_path = r"G:\pretrained_models\llm\glm-4-9b-chat"
+app_config.docs_path = r"/data/users/searchgpt/yq/TrustRAG/data/docs"
+app_config.llm_model_path = r"/data/users/searchgpt/pretrained_models/Qwen3-4B"
 retriever_config = DenseRetrieverConfig(
-    model_name_or_path=r"G:\pretrained_models\mteb\bge-large-zh-v1.5",
+    model_name_or_path=r"/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
     dim=1024,
-    index_path=r'G:\Projects\TrustRAG\examples\retrievers\dense_cache'
+    index_path=r'/data/users/searchgpt/yq/TrustRAG/examples/retrievers/dense_cache'
 )
 rerank_config = BgeRerankerConfig(
-    model_name_or_path=r"G:\pretrained_models\mteb\bge-reranker-large"
+    model_name_or_path=r"/data/users/searchgpt/pretrained_models/bge-reranker-large"
 )
 
 app_config.retriever_config = retriever_config
 app_config.rerank_config = rerank_config
 application = RagApplication(app_config)
 application.init_vector_store()
 
-
 # ========================== Config End====================
 
-
 # Create a Beijing timezone variable
 beijing_tz = pytz.timezone("Asia/Shanghai")
 IGNORE_FILE_LIST = [".DS_Store"]
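
Usage note (not part of the diff): once application.init_vector_store() has built and saved the index, the Gradio layer presumably drives the pipeline through the RagApplication.chat method defined above. A minimal sketch of that call, with a hypothetical question and the default top_k:

# Sketch only: retrieve, rerank, and generate an answer for one question.
result, history, contents, question = application.chat(
    question="What is TrustRAG?",  # hypothetical query
    top_k=5,                       # passages to retrieve before reranking
)
print(result)    # answer from the Qwen3 model, grounded in the reranked passages
print(contents)  # the reranked passages used as context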

examples/generator/vllm_curl.sh

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+curl http://localhost:8002/v1/chat/completions -H "Content-Type: application/json" -d '{
+  "model": "Qwen3-32B",
+  "messages": [
+    {"role": "user", "content": "Give me a short introduction to large language models."}
+  ],
+  "temperature": 0.6,
+  "top_p": 0.95,
+  "top_k": 20
+}'
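
Usage note (not part of the diff): the same request can be issued from Python through the OpenAI-compatible client that llm.py already imports. This is a hedged sketch assuming a vLLM server exposing Qwen3-32B on localhost:8002; the api_key is a placeholder (vLLM servers typically ignore it), and top_k is passed via extra_body because the OpenAI client has no native top_k argument:

# Sketch: Python equivalent of the curl request above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8002/v1", api_key="EMPTY")  # placeholder key
response = client.chat.completions.create(
    model="Qwen3-32B",
    messages=[{"role": "user", "content": "Give me a short introduction to large language models."}],
    temperature=0.6,
    top_p=0.95,
    extra_body={"top_k": 20},  # forwarded to the server alongside the standard fields
)
print(response.choices[0].message.content)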

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ xgboost
 bm25s
 jieba
 accelerate
-FlagEmbedding
 chardet
 openpyxl
 protobuf

trustrag/applications/rag.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from trustrag.modules.reranker.bge_reranker import BgeReranker
 from trustrag.modules.retrieval.dense_retriever import DenseRetriever
 from trustrag.modules.document.chunk import TextChunker
-from trustrag.modules.vector.embedding import FlagModelEmbedding
+from trustrag.modules.vector.embedding import SentenceTransformerEmbedding
 class ApplicationConfig():
     def __init__(self):
         self.retriever_config = None
@@ -26,7 +26,7 @@ class RagApplication():
     def __init__(self, config):
         self.config = config
         self.parser = CommonParser()
-        self.embedding_generator = FlagModelEmbedding(self.config.retriever_config.model_name_or_path)
+        self.embedding_generator = SentenceTransformerEmbedding(self.config.retriever_config.model_name_or_path)
         self.retriever = DenseRetriever(self.config.retriever_config,self.embedding_generator)
         self.reranker = BgeReranker(self.config.rerank_config)
         self.llm = GLM4Chat(self.config.llm_model_path)
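
Note (not part of the diff): this change swaps FlagModelEmbedding for SentenceTransformerEmbedding, matching the removal of FlagEmbedding from requirements.txt. A wrapper of this kind usually just delegates to sentence-transformers; the following is a hypothetical sketch of that pattern, not the actual trustrag.modules.vector.embedding implementation:

# Hypothetical sketch of a sentence-transformers-backed embedding generator.
from sentence_transformers import SentenceTransformer
import numpy as np

class SentenceTransformerEmbeddingSketch:
    def __init__(self, model_name_or_path: str):
        self.model = SentenceTransformer(model_name_or_path)

    def encode(self, texts: list) -> np.ndarray:
        # Normalized embeddings suit inner-product / cosine retrieval indexes.
        return self.model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)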

trustrag/modules/generator/llm.py

Lines changed: 184 additions & 0 deletions
@@ -180,3 +180,187 @@ def load_model(self):
         print("load model success")
 
 
+# !/usr/bin/env python
+# -*- coding:utf-8 _*-
+"""
+@author:quincy qiang
+@license: Apache Licence
+@file: llm.py
+@time: 2024/05/16
+@contact: yanqiangmiffy@gamil.com
+@software: PyCharm
+@description: coding..
+"""
+import os
+from typing import Dict, List, Any
+
+import torch
+from openai import OpenAI
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trustrag.modules.prompt.templates import SYSTEM_PROMPT, CHAT_PROMPT_TEMPLATES
+
+
+class BaseModel:
+    def __init__(self, path: str = '') -> None:
+        self.path = path
+
+    def chat(self, prompt: str, history: List[dict], content: str) -> str:
+        pass
+
+    def load_model(self):
+        pass
+
+
+class OpenAIChat(BaseModel):
+    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
+        super().__init__(path)
+        self.model = model
+
+    def chat(self, prompt: str, history: List[dict], content: str) -> str:
+        client = OpenAI()
+        client.api_key = os.getenv("OPENAI_API_KEY")
+        client.base_url = os.getenv("OPENAI_BASE_URL")
+        history.append({'role': 'user',
+                        'content': CHAT_PROMPT_TEMPLATES['RAG_PROMPT_TEMPALTE'].format(question=prompt,
+                                                                                        context=content)})
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=history,
+            max_tokens=150,
+            temperature=0.1
+        )
+        return response.choices[0].message.content
+
+
+class InternLMChat(BaseModel):
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history: List = [], content: str = '') -> str:
+        prompt = CHAT_PROMPT_TEMPLATES['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
+        response, history = self.model.chat(self.tokenizer, prompt, history)
+        return response
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
+                                                          trust_remote_code=True).cuda()
+
+
+class GLM3Chat(BaseModel):
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
+        if history is None:
+            history = []
+        if llm_only:
+            prompt = prompt
+        else:
+            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
+        response, history = self.model.chat(self.tokenizer, prompt, history, max_length=32000, num_beams=1,
+                                            do_sample=True, top_p=0.8, temperature=0.2)
+        return response, history
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
+                                                          trust_remote_code=True).cuda()
+
+
+class GLM4Chat(BaseModel):
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+
+    def chat(self, prompt: str, history=None, content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
+        if llm_only:
+            prompt = prompt
+        else:
+            prompt = CHAT_PROMPT_TEMPLATES['GLM_PROMPT_TEMPALTE'].format(system_prompt=SYSTEM_PROMPT, question=prompt,
+                                                                         context=content)
+        prompt = prompt.encode("utf-8", 'ignore').decode('utf-8', 'ignore')
+        print(prompt)
+
+        inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
+                                                    add_generation_prompt=True,
+                                                    tokenize=True,
+                                                    return_tensors="pt",
+                                                    return_dict=True
+                                                    )
+
+        inputs = inputs.to('cuda')
+        gen_kwargs = {"max_length": 5120, "do_sample": False, "top_k": 1}
+        with torch.no_grad():
+            outputs = self.model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response, history = output, []
+        return response, history
+
+    def load_model(self):
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.path,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        ).cuda().eval()
+
+
+class Qwen3Chat(BaseModel):
+    def __init__(self, path: str = '') -> None:
+        super().__init__(path)
+        self.load_model()
+        self.device = 'cuda'
+
+    def chat(self, prompt: str, history: List = [], content: str = '', llm_only: bool = False,
+             enable_thinking: bool = True) -> tuple[Any, Any]:
+        if llm_only:
+            prompt = prompt
+        else:
+            # Use an appropriate prompt template; adjust as needed
+            prompt = CHAT_PROMPT_TEMPLATES.get('DF_QWEN_PROMPT_TEMPLATE2', '{question}\n\n上下文:{context}').format(
+                question=prompt, context=content)
+
+        messages = [
+            {"role": "user", "content": prompt}
+        ]
+
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=enable_thinking  # supports Qwen3 thinking mode
+        )
+
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+        # Generate text, allowing a large output budget
+        generated_ids = self.model.generate(
+            **model_inputs,
+            max_new_tokens=32768,  # allow long generations
+            do_sample=False,
+            top_k=10
+        )
+
+        # Keep only the newly generated tokens
+        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+
+        return response, history
+
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.path,
+            torch_dtype="auto",  # let transformers pick the best dtype
+            device_map="auto",  # automatic device placement
+            trust_remote_code=True
+        )
+        print("Qwen3 model loaded successfully")
+
+
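Usage note (not part of the diff): all of these classes share the BaseModel interface, so swapping backends is a matter of constructing a different class. A hedged sketch with a placeholder checkpoint path and illustrative arguments:

from trustrag.modules.generator.llm import Qwen3Chat

llm = Qwen3Chat("/path/to/Qwen3-4B")  # placeholder local checkpoint
answer, history = llm.chat(
    prompt="What is RAG?",                                    # user question
    history=[],
    content="RAG combines retrieval with text generation.",  # retrieved context
    enable_thinking=False,                                    # skip Qwen3 thinking mode for short answers
)
print(answer)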
