Skip to content

Commit f40893c

Browse files
authored
refactor(models): unify model configuration and standardize LLM client factory (#283)
1 parent 0fd98a5 commit f40893c

File tree

21 files changed

+641
-141
lines changed

21 files changed

+641
-141
lines changed

backend/api-gateway/src/main/java/com/datamate/gateway/ApiGatewayApplication.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public RouteLocator customRouteLocator(RouteLocatorBuilder builder) {
4242
.uri("http://datamate-backend-python:18000"))
4343

4444
// 知识图谱RAG服务路由
45-
.route("graph-rag", r -> r.path("/api/rag/**")
45+
.route("python-service", r -> r.path("/api/rag/**", "/api/models/**")
4646
.uri("http://datamate-backend-python:18000"))
4747

4848
.route("deer-flow-frontend", r -> r.path("/chat/**")

backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
*/
1313
@Getter
1414
@Setter
15-
@TableName("t_model_config")
15+
@TableName("t_models")
1616
@Builder
1717
@ToString
1818
@NoArgsConstructor
@@ -47,4 +47,9 @@ public class ModelConfig extends BaseEntity<String> {
4747
* 是否默认:1-默认,0-非默认
4848
*/
4949
private Boolean isDefault;
50+
51+
/**
52+
* 是否删除:1-已删除,0-未删除
53+
*/
54+
private Boolean isDeleted;
5055
}

frontend/vite.config.ts

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { defineConfig } from "vite";
1+
import {defineConfig} from "vite";
22
import react from "@vitejs/plugin-react";
33
import tailwindcss from "@tailwindcss/vite";
44
import path from "path"; // 需要安装 Node.js 的类型声明(@types/node)
@@ -12,30 +12,55 @@ export default defineConfig({
1212
},
1313
},
1414
server: {
15-
// headers: {
16-
// "Access-Control-Allow-Origin": "*",
17-
// "access-control-allow-headers":
18-
// "Origin, X-Requested-With, Content-Type, Accept",
19-
// },
20-
proxy: {
21-
"^/api": {
22-
target: "http://localhost:8080", // 本地后端服务地址
15+
host: "0.0.0.0",
16+
proxy: (() => {
17+
const pythonProxyConfig = {
18+
target: "http://localhost:18000",
2319
changeOrigin: true,
2420
secure: false,
25-
rewrite: (path) => path.replace(/^\/api/, "/api"),
26-
configure: (proxy, options) => {
27-
// proxy 是 'http-proxy' 的实例
28-
proxy.on("proxyReq", (proxyReq, req, res) => {
29-
// 可以在这里修改请求头
30-
proxyReq.removeHeader("referer");
31-
proxyReq.removeHeader("origin");
21+
configure: (proxy: { on: (event: string, handler: (arg: unknown) => void) => void }) => {
22+
proxy.on("proxyReq", (proxyReq: unknown) => {
23+
(proxyReq as { removeHeader: (name: string) => void }).removeHeader("referer");
24+
(proxyReq as { removeHeader: (name: string) => void }).removeHeader("origin");
3225
});
33-
proxy.on("proxyRes", (proxyRes, req, res) => {
34-
delete proxyRes.headers["set-cookie"];
35-
proxyRes.headers["cookies"] = ""; // 清除 cookies 头
26+
proxy.on("proxyRes", (proxyRes: unknown) => {
27+
const res = proxyRes as { headers: Record<string, unknown> };
28+
delete res.headers["set-cookie"];
29+
res.headers["cookies"] = "";
3630
});
3731
},
38-
},
39-
},
32+
};
33+
34+
const javaProxyConfig = {
35+
target: "http://localhost:8080",
36+
changeOrigin: true,
37+
secure: false,
38+
configure: (proxy: { on: (event: string, handler: (arg: unknown) => void) => void }) => {
39+
proxy.on("proxyReq", (proxyReq: unknown) => {
40+
(proxyReq as { removeHeader: (name: string) => void }).removeHeader("referer");
41+
(proxyReq as { removeHeader: (name: string) => void }).removeHeader("origin");
42+
});
43+
proxy.on("proxyRes", (proxyRes: unknown) => {
44+
const res = proxyRes as { headers: Record<string, unknown> };
45+
delete res.headers["set-cookie"];
46+
res.headers["cookies"] = "";
47+
});
48+
},
49+
};
50+
51+
// Python 服务: rag, synthesis, annotation, data-collection, evaluation, models
52+
const pythonPaths = ["rag", "synthesis", "annotation", "data-collection", "evaluation", "models"];
53+
// Java 服务: data-management, knowledge-base, operators
54+
const javaPaths = ["data-management", "knowledge-base", "operators"];
55+
56+
const proxy: Record<string, object> = {};
57+
for (const p of pythonPaths) {
58+
proxy[`/api/${p}`] = pythonProxyConfig;
59+
}
60+
for (const p of javaPaths) {
61+
proxy[`/api/${p}`] = javaProxyConfig;
62+
}
63+
return proxy;
64+
})(),
4065
},
4166
});

runtime/datamate-python/app/db/models/model_config.py renamed to runtime/datamate-python/app/db/models/models.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
from sqlalchemy import Column, String, Integer, TIMESTAMP, select
1+
from sqlalchemy import Boolean, Column, String, TIMESTAMP
22

33
from app.db.models.base_entity import BaseEntity
44

55

6-
async def get_model_by_id(db_session, model_id: str):
7-
"""根据 ID 获取单个模型配置。"""
8-
result =await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id))
9-
model_config = result.scalar_one_or_none()
10-
return model_config
6+
class Models(BaseEntity):
7+
"""模型配置表,对应表 t_models。模型为系统级配置,RAG/生成等按 ID 引用时不受数据权限过滤。
118
12-
class ModelConfig(BaseEntity):
13-
"""模型配置表,对应表 t_model_config
14-
15-
CREATE TABLE IF NOT EXISTS t_model_config (
9+
CREATE TABLE IF NOT EXISTS t_models (
1610
id VARCHAR(36) PRIMARY KEY COMMENT '主键ID',
1711
model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)',
1812
provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)',
@@ -29,7 +23,7 @@ class ModelConfig(BaseEntity):
2923
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表';
3024
"""
3125

32-
__tablename__ = "t_model_config"
26+
__tablename__ = "t_models"
3327

3428
id = Column(String(36), primary_key=True, index=True, comment="主键ID")
3529
model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)")
@@ -38,9 +32,9 @@ class ModelConfig(BaseEntity):
3832
api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)")
3933
type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)")
4034

41-
# 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool
42-
is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用")
43-
is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认")
35+
is_enabled = Column(Boolean, nullable=False, default=True, comment="是否启用")
36+
is_default = Column(Boolean, nullable=False, default=False, comment="是否默认")
37+
is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除")
4438

4539
__table_args__ = (
4640
# 与 DDL 中的 uk_model_provider 保持一致

runtime/datamate-python/app/module/evaluation/interface/evaluation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ async def create_evaluation_task(
8080
if existing_task.scalar_one_or_none():
8181
raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists")
8282

83-
model_config = await get_model_by_id(db, request.eval_config.model_id)
84-
if not model_config:
83+
models = await get_model_by_id(db, request.eval_config.model_id)
84+
if not models:
8585
raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found")
8686

8787
# 创建评估任务
@@ -96,7 +96,7 @@ async def create_evaluation_task(
9696
eval_prompt=request.eval_prompt,
9797
eval_config=json.dumps({
9898
"modelId": request.eval_config.model_id,
99-
"modelName": model_config.model_name,
99+
"modelName": models.model_name,
100100
"dimensions": request.eval_config.dimensions,
101101
}),
102102
status=TaskStatus.PENDING.value,

runtime/datamate-python/app/module/evaluation/service/evaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_eval_prompt(self, item: EvaluationItem) -> str:
4343

4444
async def execute(self):
4545
eval_config = json.loads(self.task.eval_config)
46-
model_config = await get_model_by_id(self.db, eval_config.get("modelId"))
46+
models = await get_model_by_id(self.db, eval_config.get("modelId"))
4747
semaphore = asyncio.Semaphore(10)
4848
files = (await self.db.execute(
4949
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
@@ -55,7 +55,7 @@ async def execute(self):
5555
for file in files:
5656
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
5757
tasks = [
58-
self.evaluate_item(model_config, item, semaphore)
58+
self.evaluate_item(models, item, semaphore)
5959
for item in items
6060
]
6161
await asyncio.gather(*tasks, return_exceptions=True)
@@ -64,13 +64,13 @@ async def execute(self):
6464
self.task.eval_process = evaluated_count / total
6565
await self.db.commit()
6666

67-
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
67+
async def evaluate_item(self, models, item: EvaluationItem, semaphore: asyncio.Semaphore):
6868
async with semaphore:
6969
max_try = 3
7070
while max_try > 0:
7171
prompt_text = self.get_eval_prompt(item)
7272
resp_text = await asyncio.to_thread(
73-
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
73+
call_openai_style_model, models.base_url, models.api_key, models.model_name,
7474
prompt_text,
7575
)
7676
resp_text = extract_json_substring(resp_text)

runtime/datamate-python/app/module/generation/service/generation_service.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from app.module.shared.common.document_loaders import load_documents
2525
from app.module.shared.common.text_split import DocumentSplitter
2626
from app.module.shared.util.model_chat import extract_json_substring
27-
from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client
27+
from app.module.shared.llm import LLMFactory
28+
from app.module.system.service.common_service import get_model_by_id
2829

2930

3031
def _filter_docs(split_docs, chunk_size):
@@ -171,8 +172,12 @@ async def _process_single_file(
171172
# 为本文件构建模型 client
172173
question_model = await get_model_by_id(self.db, question_cfg.model_id)
173174
answer_model = await get_model_by_id(self.db, answer_cfg.model_id)
174-
question_chat = get_chat_client(question_model)
175-
answer_chat = get_chat_client(answer_model)
175+
question_chat = LLMFactory.create_chat(
176+
question_model.model_name, question_model.base_url, question_model.api_key
177+
)
178+
answer_chat = LLMFactory.create_chat(
179+
answer_model.model_name, answer_model.base_url, answer_model.api_key
180+
)
176181

177182
# 分批次从 DB 读取并处理 chunk
178183
batch_size = 100
@@ -356,7 +361,7 @@ async def _generate_questions_for_one_chunk(
356361
loop = asyncio.get_running_loop()
357362
raw_answer = await loop.run_in_executor(
358363
None,
359-
chat,
364+
LLMFactory.invoke_sync,
360365
question_chat,
361366
prompt,
362367
)
@@ -400,7 +405,7 @@ async def process_single_question(question: str):
400405
loop = asyncio.get_running_loop()
401406
answer = await loop.run_in_executor(
402407
None,
403-
chat,
408+
LLMFactory.invoke_sync,
404409
answer_chat,
405410
prompt_local,
406411
)

runtime/datamate-python/app/module/rag/interface/rag_interface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
router = APIRouter(prefix="/rag", tags=["rag"])
1010

1111
@router.post("/process/{knowledge_base_id}")
12-
async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depends(get_db)):
12+
async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()):
1313
"""
1414
Process all unprocessed files in a knowledge base.
1515
"""
1616
try:
17-
await RAGService(db).init_graph_rag(knowledge_base_id)
17+
await rag_service.init_graph_rag(knowledge_base_id)
1818
return StandardResponse(
1919
code=200,
2020
message="Processing started for knowledge base.",
@@ -24,12 +24,11 @@ async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depe
2424
raise HTTPException(status_code=500, detail=str(e))
2525

2626
@router.post("/query")
27-
async def query_knowledge_graph(payload: QueryRequest, db: AsyncSession = Depends(get_db)):
27+
async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()):
2828
"""
2929
Query the knowledge graph with the given query text and knowledge base ID.
3030
"""
3131
try:
32-
rag_service = RAGService(db)
3332
result = await rag_service.query_rag(payload.query, payload.knowledge_base_id)
3433
return StandardResponse(code=200, message="success", data=result)
3534
except HTTPException:

runtime/datamate-python/app/module/rag/service/rag_service.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import asyncio
33
from typing import Optional, Sequence
44

5-
from fastapi import BackgroundTasks, Depends
5+
from fastapi import Depends
66
from sqlalchemy import select
77
from sqlalchemy.ext.asyncio import AsyncSession
88

99
from app.core.logging import get_logger
1010
from app.db.models.dataset_management import DatasetFiles
1111
from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase
12-
from app.db.models.model_config import ModelConfig
1312
from app.db.session import get_db, AsyncSessionLocal
1413
from app.module.shared.common.document_loaders import load_documents
1514
from .graph_rag import (
@@ -18,7 +17,8 @@
1817
build_llm_model_func,
1918
initialize_rag,
2019
)
21-
from ...system.service.common_service import get_embedding_dimension, get_openai_client
20+
from app.module.shared.llm import LLMFactory
21+
from ...system.service.common_service import get_model_by_id
2222

2323
logger = get_logger(__name__)
2424

@@ -27,10 +27,10 @@ class RAGService:
2727
def __init__(
2828
self,
2929
db: AsyncSession = Depends(get_db),
30-
background_tasks: BackgroundTasks | None = None,
30+
3131
):
3232
self.db = db
33-
self.background_tasks = background_tasks
33+
self.background_tasks = None
3434
self.rag = None
3535

3636
async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]:
@@ -44,8 +44,8 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil
4444

4545
async def init_graph_rag(self, knowledge_base_id: str):
4646
kb = await self._get_knowledge_base(knowledge_base_id)
47-
embedding_model = await self._get_model_config(kb.embedding_model)
48-
chat_model = await self._get_model_config(kb.chat_model)
47+
embedding_model = await self._get_models(kb.embedding_model)
48+
chat_model = await self._get_models(kb.chat_model)
4949

5050
llm_callable = await build_llm_model_func(
5151
chat_model.model_name, chat_model.base_url, chat_model.api_key
@@ -54,7 +54,9 @@ async def init_graph_rag(self, knowledge_base_id: str):
5454
embedding_model.model_name,
5555
embedding_model.base_url,
5656
embedding_model.api_key,
57-
embedding_dim=get_embedding_dimension(get_openai_client(embedding_model)),
57+
embedding_dim=LLMFactory.get_embedding_dimension(
58+
embedding_model.model_name, embedding_model.base_url, embedding_model.api_key
59+
),
5860
)
5961

6062
kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name)
@@ -124,14 +126,13 @@ async def _get_knowledge_base(self, knowledge_base_id: str):
124126
raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.")
125127
return knowledge_base
126128

127-
async def _get_model_config(self, model_id: Optional[str]):
129+
async def _get_models(self, model_id: Optional[str]):
128130
if not model_id:
129131
raise ValueError("Model ID is required for initializing RAG.")
130-
result = await self.db.execute(select(ModelConfig).where(ModelConfig.id == model_id))
131-
model = result.scalars().first()
132-
if not model:
133-
raise ValueError(f"Model config with ID {model_id} not found.")
134-
return model
132+
models = await get_model_by_id(self.db, model_id)
133+
if not models:
134+
raise ValueError(f"Models with ID {model_id} not found.")
135+
return models
135136

136137
async def query_rag(self, query: str, knowledge_base_id: str) -> str:
137138
if not self.rag:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# app/module/shared/llm/__init__.py
2+
"""
3+
LangChain 模型工厂:统一创建 Chat、Embedding 及健康检查,便于各模块复用。
4+
"""
5+
from .factory import LLMFactory
6+
7+
__all__ = ["LLMFactory"]

0 commit comments

Comments
 (0)