|
10 | 10 | import sys |
11 | 11 |
|
12 | 12 | sys.path.append(".") |
13 | | -import os |
14 | 13 | import shutil |
15 | 14 | import time |
16 | 15 | import gradio as gr |
17 | 16 | import loguru |
18 | 17 | import pandas as pd |
19 | 18 | from datetime import datetime |
20 | 19 | import pytz |
21 | | -from trustrag.config.config_loader import ConfigLoader |
22 | | -from trustrag.applications.rag import RagApplication, ApplicationConfig |
23 | 20 | from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig |
24 | 21 | from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig |
| 22 | +import os |
| 23 | +from trustrag.modules.citation.match_citation import MatchCitation |
| 24 | +from trustrag.modules.document.common_parser import CommonParser |
| 25 | +from trustrag.modules.generator.llm import Qwen3Chat |
| 26 | +from trustrag.modules.reranker.bge_reranker import BgeReranker |
| 27 | +from trustrag.modules.retrieval.dense_retriever import DenseRetriever |
| 28 | +from trustrag.modules.document.chunk import TextChunker |
| 29 | +from trustrag.modules.vector.embedding import SentenceTransformerEmbedding |
| 30 | + |
| 31 | + |
class ApplicationConfig:
    """Mutable settings holder consumed by ``RagApplication``.

    Every field starts as ``None`` and is expected to be assigned by the
    caller before the object is handed to ``RagApplication``.  Declaring all
    fields here (instead of letting callers create them ad hoc) means a
    forgotten setting shows up as an explicit ``None`` rather than an
    ``AttributeError`` deep inside the pipeline constructor.
    """

    def __init__(self):
        # Passed to DenseRetriever; RagApplication also reads its
        # ``model_name_or_path`` and ``index_path`` attributes.
        self.retriever_config = None
        # Passed to BgeReranker.
        self.rerank_config = None
        # Directory listed by RagApplication.init_vector_store().
        self.docs_path = None
        # Passed to the LLM wrapper built in RagApplication.__init__().
        self.llm_model_path = None
| 36 | + |
| 37 | + |
class RagApplication:
    """Document QA pipeline built from a parser, chunker, dense retriever,
    reranker and chat LLM.

    The constructor instantiates every component from ``config``; the index
    is then either built from a document directory (``init_vector_store``)
    or loaded from disk (``load_vector_store``) before ``chat`` is used.
    """

    def __init__(self, config):
        """Create all pipeline components.

        Args:
            config: Settings object exposing ``retriever_config`` (with
                ``model_name_or_path`` and ``index_path``), ``rerank_config``,
                ``llm_model_path`` and ``docs_path``.
        """
        self.config = config
        self.parser = CommonParser()
        self.embedding_generator = SentenceTransformerEmbedding(
            self.config.retriever_config.model_name_or_path
        )
        self.retriever = DenseRetriever(self.config.retriever_config, self.embedding_generator)
        self.reranker = BgeReranker(self.config.rerank_config)
        self.llm = Qwen3Chat(self.config.llm_model_path)
        self.mc = MatchCitation()
        self.tc = TextChunker()

    @staticmethod
    def _paragraphs_to_texts(paragraphs):
        """Flatten parser output into a list of plain strings.

        Args:
            paragraphs: Whatever the parser returned: a list of strings, a
                list of dicts, a single non-list value, or something empty.

        Returns:
            list[str]: One string per paragraph.  Dict items are rendered as
            their values joined by a single space; any other item is passed
            through ``str``.  Empty / ``None`` input yields ``[]``; a
            non-list truthy value yields a one-element list.
        """
        if not paragraphs:
            return []
        if not isinstance(paragraphs, list):
            return [str(paragraphs)]
        # Decide per item (not from the first element only) so a mixed list
        # of dicts and strings is still handled.
        return [
            ' '.join(str(value) for value in item.values()) if isinstance(item, dict) else str(item)
            for item in paragraphs
        ]

    def init_vector_store(self, chunk_size: int = 256):
        """Parse every entry in ``config.docs_path``, chunk it, then build and
        save the retriever index.

        Entries the parser cannot handle are reported and skipped so one bad
        file does not abort the whole build.

        Args:
            chunk_size: Size argument forwarded to the text chunker.
        """
        print("init_vector_store ... ")
        all_paragraphs = []
        all_chunks = []
        for filename in os.listdir(self.config.docs_path):
            file_path = os.path.join(self.config.docs_path, filename)
            try:
                paragraphs = self.parser.parse(file_path)
            except Exception as exc:
                # Best-effort: skip the entry, but say so.  ``Exception``
                # (not a bare except) lets Ctrl-C / SystemExit propagate.
                print(f"skip {file_path}: {exc!r}")
                continue
            all_paragraphs.append(paragraphs)
        print("chunking for paragraphs")
        for paragraphs in all_paragraphs:
            text_list = self._paragraphs_to_texts(paragraphs)
            all_chunks.extend(self.tc.get_chunks(text_list, chunk_size))

        self.retriever.build_from_texts(all_chunks)
        print("init_vector_store done! ")
        self.retriever.save_index(self.config.retriever_config.index_path)

    def load_vector_store(self):
        """Load the retriever index from ``config.retriever_config.index_path``."""
        self.retriever.load_index(self.config.retriever_config.index_path)

    def add_document(self, file_path, chunk_size: int = 256):
        """Parse one file and add its chunks to the retriever.

        The parser output goes through the same normalisation and chunking
        as ``init_vector_store`` so that incrementally added documents are
        indexed in the same form as the initial corpus.

        Args:
            file_path: Path of the document to parse.
            chunk_size: Size argument forwarded to the text chunker.
        """
        text_list = self._paragraphs_to_texts(self.parser.parse(file_path))
        for chunk in self.tc.get_chunks(text_list, chunk_size):
            self.retriever.add_text(chunk)
        print("add_document done!")

    def chat(self, question: str = '', top_k: int = 5):
        """Answer ``question`` using retrieved and reranked context.

        Args:
            question: User query.
            top_k: Number of candidates requested from the retriever.

        Returns:
            tuple: ``(result, history, contents, question)`` where ``result``
            and ``history`` come from the LLM call and ``contents`` is the
            reranker output.
        """
        retrieved = self.retriever.retrieve(query=question, top_k=top_k)
        reranked = self.reranker.rerank(
            query=question,
            documents=[hit['text'] for hit in retrieved],
        )
        # The LLM receives the reranked passages as one newline-joined context.
        context = '\n'.join(doc['text'] for doc in reranked)
        result, history = self.llm.chat(question, [], context)
        return result, history, reranked, question
25 | 99 |
|
26 | 100 |
|
27 | 101 | # ========================== Config Start==================== |
28 | | -# # 创建全局配置实例 |
29 | | -# config = ConfigLoader(config_path="config_local.json") |
30 | | -# app_config = ApplicationConfig() |
31 | | -# |
32 | | -# llm_model = config.get_config('models.llm') |
33 | | -# embedding_model = config.get_config('models.embedding') |
34 | | -# reranker_model = config.get_config('models.reranker') |
35 | | -# |
36 | | -# # 加载配置 |
37 | | -# app_config.docs_path = config.get_config('paths.docs') |
38 | | -# retriever_config = DenseRetrieverConfig( |
39 | | -# model_name_or_path=embedding_model["path"], |
40 | | -# dim=1024, |
41 | | -# index_path=config.get_config('index') |
42 | | -# ) |
43 | | -# rerank_config = BgeRerankerConfig( |
44 | | -# model_name_or_path=reranker_model["path"], |
45 | | -# ) |
# Model and data locations default to the paths of the original deployment.
# Each one can be overridden through the environment variable named next to
# it, so the demo can run on another machine without editing this file.
app_config = ApplicationConfig()
app_config.docs_path = os.environ.get(
    "TRUSTRAG_DOCS_PATH",
    r"/data/users/searchgpt/yq/TrustRAG/data/docs",
)
app_config.llm_model_path = os.environ.get(
    "TRUSTRAG_LLM_MODEL_PATH",
    r"/data/users/searchgpt/pretrained_models/Qwen3-4B",
)
retriever_config = DenseRetrieverConfig(
    model_name_or_path=os.environ.get(
        "TRUSTRAG_EMBEDDING_MODEL_PATH",
        r"/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
    ),
    dim=1024,
    index_path=os.environ.get(
        "TRUSTRAG_INDEX_PATH",
        r'/data/users/searchgpt/yq/TrustRAG/examples/retrievers/dense_cache',
    ),
)
rerank_config = BgeRerankerConfig(
    model_name_or_path=os.environ.get(
        "TRUSTRAG_RERANKER_MODEL_PATH",
        r"/data/users/searchgpt/pretrained_models/bge-reranker-large",
    ),
)

app_config.retriever_config = retriever_config
app_config.rerank_config = rerank_config
application = RagApplication(app_config)
# Builds and saves the index at import time from everything under docs_path.
application.init_vector_store()
62 | 118 |
|
63 | | - |
64 | 119 | # ========================== Config End==================== |
65 | 120 |
|
66 | | - |
67 | 121 | # Time zone object for Beijing time (tz database name "Asia/Shanghai") |
68 | 122 | beijing_tz = pytz.timezone("Asia/Shanghai") |
69 | 123 | IGNORE_FILE_LIST = [".DS_Store"] |
|
0 commit comments