# rag_app_v5.py
from langchain_community.document_loaders import (
PDFPlumberLoader,
TextLoader,
UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader,
UnstructuredExcelLoader,
CSVLoader,
UnstructuredMarkdownLoader,
UnstructuredXMLLoader,
UnstructuredHTMLLoader,
)
from typing import List, Dict
from TextSplitter import TextSplitter, chunk_regex
from sentence_transformers import SentenceTransformer
import os
import dashscope
from http import HTTPStatus
import chromadb
import uuid
import shutil
from rank_bm25 import BM25Okapi
import jieba
from FlagEmbedding import FlagReranker  # reranker used to re-score retrieved chunks against the query
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QWEN_MODEL = "qwen-turbo"
QWEN_API_KEY = "your_api_key"  # replace with your DashScope API key (better: read it from an environment variable)
DOCUMENT_LOADER_MAPPING = {
".pdf": (PDFPlumberLoader, {}),
".txt": (TextLoader, {"encoding": "utf8"}),
".doc": (UnstructuredWordDocumentLoader, {}),
".docx": (UnstructuredWordDocumentLoader, {}),
".ppt": (UnstructuredPowerPointLoader, {}),
".pptx": (UnstructuredPowerPointLoader, {}),
".xlsx": (UnstructuredExcelLoader, {}),
".csv": (CSVLoader, {}),
".md": (UnstructuredMarkdownLoader, {}),
".xml": (UnstructuredXMLLoader, {}),
".html": (UnstructuredHTMLLoader, {}),
}
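# To support another file type, register its loader in the mapping above. A
# minimal sketch (".htm" is an assumption, not part of the original mapping):
# DOCUMENT_LOADER_MAPPING[".htm"] = (UnstructuredHTMLLoader, {})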
def load_document(file_path):
    """Load a single document and return its text, or "" if the type is unsupported."""
    ext = os.path.splitext(file_path)[1].lower()
    loader_class, loader_args = DOCUMENT_LOADER_MAPPING.get(ext, (None, None))
    if loader_class:
        loader = loader_class(file_path, **loader_args)
        documents = loader.load()
        content = "\n".join([doc.page_content for doc in documents])
        return content
    print(f"Unsupported document type: '{ext}'")
    return ""
def load_embedding_model(model_path='rag_app/bge-small-zh-v1.5'):
    """Load the local bge-small-zh-v1.5 SentenceTransformer used for indexing and retrieval."""
    print("Loading embedding model...")
    embedding_model = SentenceTransformer(os.path.abspath(model_path))
    print(f"bge-small-zh-v1.5 max input length: {embedding_model.max_seq_length}\n")
    return embedding_model
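# Usage sketch for the embedding model (the 512-dimension figure is an
# assumption taken from the bge-small-zh-v1.5 model card; verify for your copy):
# model = load_embedding_model()
# vec = model.encode("检索增强生成", normalize_embeddings=True)
# print(vec.shape)  # expected: (512,)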
def reranking(query, chunks, top_k=3):
    """Re-score chunks against the query with a cross-encoder and return the top_k."""
    # Initialize the reranking model (BAAI/bge-reranker-v2-m3); for repeated calls,
    # consider loading it once at module level instead of once per call.
    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
    # Build input pairs: each chunk is paired with the query
    input_pairs = [[query, chunk] for chunk in chunks]
    # Compute a semantic relevance score for every (query, chunk) pair
    scores = reranker.compute_score(input_pairs, normalize=True)
    print("Reranking scores:", scores)
    # Sort by score and keep the top_k chunks (never more than we actually have)
    top_k = min(top_k, len(chunks))
    sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    reranking_chunks = [chunks[i] for i in sorted_indices[:top_k]]
    # Print the top_k reranked chunks with their scores
    for i in range(top_k):
        print(f"Reranked chunk {i+1}: score {scores[sorted_indices[i]]}, content: {reranking_chunks[i]}\n")
    return reranking_chunks
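# Usage sketch for reranking (toy chunks; exact scores depend on the downloaded
# BAAI/bge-reranker-v2-m3 weights):
# best = reranking("什么是RAG?",
#                  ["RAG 将检索与生成结合起来。", "今天天气很好。"],
#                  top_k=1)
# The relevant first chunk should score well above the unrelated second one.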
def indexing_process(folder_path: str, embedding_model, collection):
    """Load every supported file in folder_path, chunk it, embed the chunks, and store them in Chroma."""
    all_chunks: List[Dict[str, str]] = []
    all_ids: List[str] = []
    splitter = TextSplitter(chunk_regex)
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        if os.path.isfile(file_path):
            document_text = load_document(file_path)
            if document_text:
                print(f"Total characters in {filename}: {len(document_text)}")
                chunks_with_metadata = splitter.split_with_metadata(document_text)
                print(f"Number of chunks from {filename}: {len(chunks_with_metadata)}")
                for chunk in chunks_with_metadata:
                    # Record the source filename in the chunk metadata
                    chunk['metadata']['filename'] = filename
                    all_chunks.append(chunk)
                    all_ids.append(str(uuid.uuid4()))
    if not all_chunks:
        print("No indexable documents found; nothing was added to the vector store.")
        return
    # Prepare the data to store in the vector database
    documents = [chunk['content'] for chunk in all_chunks]
    metadatas = [chunk['metadata'] for chunk in all_chunks]
    # Generate embeddings (batch-encoding all chunks at once is faster than one call per chunk)
    embeddings = embedding_model.encode(documents, normalize_embeddings=True).tolist()
    # Add everything to the collection
    collection.add(
        ids=all_ids,
        embeddings=embeddings,
        documents=documents,
        metadatas=metadatas
    )
    print("Embeddings generated and stored in the vector database.")
    print("Indexing complete.")
    print("********************************************************")
def retrieval_process(query, collection, embedding_model, top_k=6):
    """Hybrid retrieval: combine dense (vector) and sparse (BM25) candidates, then rerank."""
    # Dense retrieval: embed the query and search the Chroma collection
    query_embedding = embedding_model.encode(query, normalize_embeddings=True).tolist()
    vector_results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
    # Sparse retrieval: BM25 over the jieba-tokenized corpus
    all_docs = collection.get()['documents']
    tokenized_corpus = [list(jieba.cut(doc)) for doc in all_docs]
    bm25 = BM25Okapi(tokenized_corpus)
    tokenized_query = list(jieba.cut(query))
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_k_indices = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:top_k]
    bm25_chunks = [all_docs[i] for i in bm25_top_k_indices]
    print(f"Query: {query}")
    print(f"Top {top_k} chunks from vector retrieval:")
    vector_chunks = []
    for rank, (doc_id, doc) in enumerate(zip(vector_results['ids'][0], vector_results['documents'][0])):
        print(f"Vector retrieval rank: {rank + 1}")
        print(f"Chunk ID: {doc_id}")
        print(f"Chunk content:\n{doc}\n")
        vector_chunks.append(doc)
    print(f"Top {top_k} chunks from BM25 retrieval:")
    for rank, doc in enumerate(bm25_chunks):
        print(f"BM25 retrieval rank: {rank + 1}")
        print(f"Chunk content:\n{doc}\n")
    # Merge the two candidate lists, dropping duplicates (both retrievers often
    # return the same chunk), then rerank and keep the top_k
    merged_chunks = list(dict.fromkeys(vector_chunks + bm25_chunks))
    reranking_chunks = reranking(query, merged_chunks, top_k)
    print("Retrieval complete.")
    print("********************************************************")
    # Return the top_k chunks after reranking
    return reranking_chunks
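# Usage sketch for hybrid retrieval (assumes indexing_process has already run
# against the same collection; the query string is a placeholder):
# chunks = retrieval_process("示例查询", collection, embedding_model, top_k=3)
# for c in chunks:
#     print(c[:100])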
def generate_process(query, chunks):
    """Build a grounded prompt from the retrieved chunks and stream an answer from Qwen."""
    llm_model = QWEN_MODEL
    dashscope.api_key = QWEN_API_KEY
    # Concatenate the retrieved chunks as numbered reference documents; the
    # prompt stays in Chinese because the corpus and queries are Chinese
    context = ""
    for i, chunk in enumerate(chunks):
        context += f"参考文档{i+1}: \n{chunk}\n\n"
    prompt = f"根据参考文档回答问题:{query}\n\n{context}"
    print(prompt + "\n")
    messages = [{'role': 'user', 'content': prompt}]
    try:
        responses = dashscope.Generation.call(
            model=llm_model,
            messages=messages,
            result_format='message',
            stream=True,
            incremental_output=True
        )
        generated_response = ""
        print("Generation started:")
        for response in responses:
            if response.status_code == HTTPStatus.OK:
                content = response.output.choices[0]['message']['content']
                generated_response += content
                print(content, end='')
            else:
                print(f"Request failed: {response.status_code} - {response.message}")
                return None
        print("\nGeneration complete.")
        print("********************************************************")
        return generated_response
    except Exception as e:
        print(f"Error during LLM generation: {e}")
        return None
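# Usage sketch for generation (requires a valid QWEN_API_KEY; the chunks below
# are placeholders standing in for real retrieval output):
# answer = generate_process("示例问题", ["参考内容一", "参考内容二"])
# print(answer)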
def main():
    print("RAG pipeline started.")
    # Rebuild the vector store from scratch on every run
    chroma_db_path = os.path.abspath("rag_app/chroma_db")
    if os.path.exists(chroma_db_path):
        shutil.rmtree(chroma_db_path)
    client = chromadb.PersistentClient(path=chroma_db_path)
    collection = client.get_or_create_collection(name="documents")
    embedding_model = load_embedding_model()
    indexing_process('rag_app/data_lesson6', embedding_model, collection)
    # The query is kept in Chinese because the indexed documents are Chinese; it asks:
    # "Which industries' cases does the report cover, and what challenges does each face?"
    query = "下面报告中涉及了哪几个行业的案例以及总结各自面临的挑战?"
    retrieval_chunks = retrieval_process(query, collection, embedding_model)
    generate_process(query, retrieval_chunks)
    print("RAG pipeline finished.")

if __name__ == "__main__":
    main()