Commit 7d7f245

committed: update
1 parent 2e77636 commit 7d7f245

File tree

4 files changed (+101, -107 lines)


app_local_model.py

Lines changed: 82 additions & 28 deletions
@@ -10,60 +10,114 @@
 import sys
 
 sys.path.append(".")
-import os
 import shutil
 import time
 import gradio as gr
 import loguru
 import pandas as pd
 from datetime import datetime
 import pytz
-from trustrag.config.config_loader import ConfigLoader
-from trustrag.applications.rag import RagApplication, ApplicationConfig
 from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig
 from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig
+import os
+from trustrag.modules.citation.match_citation import MatchCitation
+from trustrag.modules.document.common_parser import CommonParser
+from trustrag.modules.generator.llm import Qwen3Chat
+from trustrag.modules.reranker.bge_reranker import BgeReranker
+from trustrag.modules.retrieval.dense_retriever import DenseRetriever
+from trustrag.modules.document.chunk import TextChunker
+from trustrag.modules.vector.embedding import SentenceTransformerEmbedding
+
+
+class ApplicationConfig():
+    def __init__(self):
+        self.retriever_config = None
+        self.rerank_config = None
+
+
+class RagApplication():
+    def __init__(self, config):
+        self.config = config
+        self.parser = CommonParser()
+        self.embedding_generator = SentenceTransformerEmbedding(self.config.retriever_config.model_name_or_path)
+        self.retriever = DenseRetriever(self.config.retriever_config, self.embedding_generator)
+        self.reranker = BgeReranker(self.config.rerank_config)
+        self.llm = Qwen3Chat(self.config.llm_model_path)
+        self.mc = MatchCitation()
+        self.tc = TextChunker()
+
+    def init_vector_store(self):
+        """Parse all files under docs_path, chunk them, and build the dense index."""
+        print("init_vector_store ... ")
+        all_paragraphs = []
+        all_chunks = []
+        for filename in os.listdir(self.config.docs_path):
+            file_path = os.path.join(self.config.docs_path, filename)
+            try:
+                paragraphs = self.parser.parse(file_path)
+                all_paragraphs.append(paragraphs)
+            except:
+                pass
+        print("chunking for paragraphs")
+        for paragraphs in all_paragraphs:
+            # Make sure paragraphs is a list and normalize its elements
+            if isinstance(paragraphs, list) and paragraphs:
+                if isinstance(paragraphs[0], dict):
+                    # list[dict] -> list[str]
+                    text_list = [' '.join(str(value) for value in item.values()) for item in paragraphs]
+                else:
+                    # already list[str]
+                    text_list = [str(item) for item in paragraphs]
+            else:
+                # any other shape: wrap non-empty values in a single-element list
+                text_list = [str(paragraphs)] if paragraphs else []
+
+            chunks = self.tc.get_chunks(text_list, 256)
+            all_chunks.extend(chunks)
+
+        self.retriever.build_from_texts(all_chunks)
+        print("init_vector_store done! ")
+        self.retriever.save_index(self.config.retriever_config.index_path)
+
+    def load_vector_store(self):
+        self.retriever.load_index(self.config.retriever_config.index_path)
+
+    def add_document(self, file_path):
+        chunks = self.parser.parse(file_path)
+        for chunk in chunks:
+            self.retriever.add_text(chunk)
+        print("add_document done!")
+
+    def chat(self, question: str = '', top_k: int = 5):
+        contents = self.retriever.retrieve(query=question, top_k=top_k)
+        contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
+        content = '\n'.join([content['text'] for content in contents])
+        result, history = self.llm.chat(question, [], content)
+        return result, history, contents, question
 
 
 # ========================== Config Start====================
-# # Create a global configuration instance
-# config = ConfigLoader(config_path="config_local.json")
-# app_config = ApplicationConfig()
-#
-# llm_model = config.get_config('models.llm')
-# embedding_model = config.get_config('models.embedding')
-# reranker_model = config.get_config('models.reranker')
-#
-# # Load the configuration
-# app_config.docs_path = config.get_config('paths.docs')
-# retriever_config = DenseRetrieverConfig(
-#     model_name_or_path=embedding_model["path"],
-#     dim=1024,
-#     index_path=config.get_config('index')
-# )
-# rerank_config = BgeRerankerConfig(
-#     model_name_or_path=reranker_model["path"],
-# )
 app_config = ApplicationConfig()
-app_config.docs_path = r"G:\Projects\TrustRAG\data\docs"
-app_config.llm_model_path = r"G:\pretrained_models\llm\glm-4-9b-chat"
+app_config.docs_path = r"/data/users/searchgpt/yq/TrustRAG/data/docs"
+app_config.llm_model_path = r"/data/users/searchgpt/pretrained_models/Qwen3-4B"
 retriever_config = DenseRetrieverConfig(
-    model_name_or_path=r"G:\pretrained_models\mteb\bge-large-zh-v1.5",
+    model_name_or_path=r"/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
     dim=1024,
-    index_path=r'G:\Projects\TrustRAG\examples\retrievers\dense_cache'
+    index_path=r'/data/users/searchgpt/yq/TrustRAG/examples/retrievers/dense_cache'
 )
 rerank_config = BgeRerankerConfig(
-    model_name_or_path=r"G:\pretrained_models\mteb\bge-reranker-large"
+    model_name_or_path=r"/data/users/searchgpt/pretrained_models/bge-reranker-large"
 )
 
 app_config.retriever_config = retriever_config
 app_config.rerank_config = rerank_config
 application = RagApplication(app_config)
 application.init_vector_store()
 
-
 # ========================== Config End====================
 
-
 # Create a variable for the Beijing timezone
 beijing_tz = pytz.timezone("Asia/Shanghai")
 IGNORE_FILE_LIST = [".DS_Store"]
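Note: the normalization branch in init_vector_store above coerces whatever CommonParser.parse returns (list[dict], list[str], or a bare value) into a flat list[str] before chunking. A minimal standalone sketch of just that step, with hypothetical sample inputs and no trustrag dependency:

def normalize_paragraphs(paragraphs):
    # Mirrors the logic in init_vector_store: coerce the parser output
    # into a flat list[str] suitable for TextChunker.get_chunks.
    if isinstance(paragraphs, list) and paragraphs:
        if isinstance(paragraphs[0], dict):
            # list[dict] -> list[str]: join each dict's values into one string
            return [' '.join(str(v) for v in item.values()) for item in paragraphs]
        # already list[str] (or coercible to it)
        return [str(item) for item in paragraphs]
    # anything else: wrap non-empty values, drop empty ones
    return [str(paragraphs)] if paragraphs else []

print(normalize_paragraphs([{"title": "t1", "body": "b1"}]))  # ['t1 b1']
print(normalize_paragraphs(["plain", "strings"]))             # ['plain', 'strings']
print(normalize_paragraphs(""))                               # []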

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ xgboost
 bm25s
 jieba
 accelerate
-FlagEmbedding
 chardet
 openpyxl
 protobuf

trustrag/applications/rag.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from trustrag.modules.reranker.bge_reranker import BgeReranker
 from trustrag.modules.retrieval.dense_retriever import DenseRetriever
 from trustrag.modules.document.chunk import TextChunker
-from trustrag.modules.vector.embedding import FlagModelEmbedding
+from trustrag.modules.vector.embedding import SentenceTransformerEmbedding
 class ApplicationConfig():
     def __init__(self):
         self.retriever_config = None
@@ -26,7 +26,7 @@ class RagApplication():
     def __init__(self, config):
         self.config = config
         self.parser = CommonParser()
-        self.embedding_generator = FlagModelEmbedding(self.config.retriever_config.model_name_or_path)
+        self.embedding_generator = SentenceTransformerEmbedding(self.config.retriever_config.model_name_or_path)
         self.retriever = DenseRetriever(self.config.retriever_config,self.embedding_generator)
         self.reranker = BgeReranker(self.config.rerank_config)
         self.llm = GLM4Chat(self.config.llm_model_path)
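Note: this keeps rag.py in sync with the FlagModelEmbedding -> SentenceTransformerEmbedding swap made in app_local_model.py and embedding.py. The new wrapper simply delegates to sentence-transformers; a hedged sketch of the equivalent raw call (the checkpoint and query text are illustrative placeholders, not part of this commit):

from sentence_transformers import SentenceTransformer

# Any sentence-transformers-compatible checkpoint works here;
# bge-large-zh-v1.5 matches the dim=1024 retriever config used above.
model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
vectors = model.encode(["what is TrustRAG?"], show_progress_bar=False)
print(vectors.shape)  # (1, 1024)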

trustrag/modules/vector/embedding.py

Lines changed: 17 additions & 76 deletions
@@ -1,9 +1,10 @@
 import os
-from typing import List, Optional
+from typing import List, Dict
+from typing import Optional
 
 import numpy as np
+import requests
 import torch
-from FlagEmbedding import FlagAutoModel
 from openai import OpenAI
 from sentence_transformers import SentenceTransformer
 from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -12,6 +13,20 @@
 from trustrag.modules.vector.base import EmbeddingGenerator
 
 
+class SentenceTransformerEmbedding(EmbeddingGenerator):
+    def __init__(
+        self,
+        model_name_or_path: str = "sentence-transformers/multi-qa-mpnet-base-cos-v1",
+        device: str = None
+    ):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = SentenceTransformer(model_name_or_path, device=self.device)
+        self.embedding_size = self.model.get_sentence_embedding_dimension()
+
+    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
+        return self.model.encode(texts, show_progress_bar=False)
+
+
 class OpenAIEmbedding(EmbeddingGenerator):
     def __init__(
         self,
@@ -36,20 +51,6 @@ def generate_embeddings(self, texts: List[str]) -> np.ndarray:
         return np.array([data.embedding for data in response.data])
 
 
-class SentenceTransformerEmbedding(EmbeddingGenerator):
-    def __init__(
-        self,
-        model_name_or_path: str = "sentence-transformers/multi-qa-mpnet-base-cos-v1",
-        device: str = None
-    ):
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = SentenceTransformer(model_name_or_path, device=self.device)
-        self.embedding_size = self.model.get_sentence_embedding_dimension()
-
-    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
-        return self.model.encode(texts, show_progress_bar=False)
-
-
 class HuggingFaceEmbedding(EmbeddingGenerator):
     def __init__(
         self,
@@ -113,66 +114,6 @@ def generate_embeddings(self, texts: List[str]) -> np.ndarray:
         return np.array(embeddings)
 
 
-class FlagModelEmbedding(EmbeddingGenerator):
-    def __init__(
-        self,
-        model_name: str = "BAAI/bge-base-en-v1.5",
-        query_instruction: Optional[str] = "Represent this sentence for searching relevant passages:",
-        use_fp16: bool = True,
-        device: str = None
-    ):
-        """
-        Initialize FlagModel embedding generator.
-
-        Args:
-            model_name (str): Name or path of the model
-            query_instruction (str, optional): Instruction prefix for queries
-            use_fp16 (bool): Whether to use FP16 for inference
-            device (str, optional): Device to run the model on
-        """
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = FlagAutoModel.from_finetuned(
-            model_name,
-            query_instruction_for_retrieval=query_instruction,
-            use_fp16=use_fp16,
-            devices=self.device
-        )
-        # if self.device == "cuda":
-        #     self.model.to(device)
-
-    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
-        """
-        Generate embeddings for a list of texts.
-
-        Args:
-            texts (List[str]): List of texts to generate embeddings for
-
-        Returns:
-            np.ndarray: Array of embeddings
-        """
-        embeddings = self.model.encode(texts)
-        return np.array(embeddings)
-
-    def compute_similarity(self, embeddings1: np.ndarray, embeddings2: np.ndarray) -> np.ndarray:
-        """
-        Compute similarity matrix between two sets of embeddings using inner product.
-
-        Args:
-            embeddings1 (np.ndarray): First set of embeddings
-            embeddings2 (np.ndarray): Second set of embeddings
-
-        Returns:
-            np.ndarray: Similarity matrix
-        """
-        return embeddings1 @ embeddings2.T
-
-
-import requests
-import numpy as np
-from typing import List, Dict, Any
-from abc import ABC, abstractmethod
-
-
 class CustomServerEmbedding(EmbeddingGenerator):
     """
     Implementation of EmbeddingGenerator that uses a remote embedding service.
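Note: the deleted FlagModelEmbedding also carried a compute_similarity helper (a plain inner product). A minimal sketch of the same computation using the surviving SentenceTransformerEmbedding, assuming trustrag is importable and the default checkpoint can be downloaded; the query and document strings are illustrative only:

from trustrag.modules.vector.embedding import SentenceTransformerEmbedding

# Defaults to sentence-transformers/multi-qa-mpnet-base-cos-v1 (768-dim).
emb = SentenceTransformerEmbedding()
queries = emb.generate_embeddings(["how to build an index"])
docs = emb.generate_embeddings(["index construction", "weather today"])
# Inner-product similarity matrix, shape (1, 2); for normalized
# embeddings this equals cosine similarity.
print(queries @ docs.T)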
