1
+ # -*- coding: utf-8 -*-
2
+ import time
1
3
import uvicorn
2
4
import asyncio
3
- from fastapi import FastAPI , Request
5
+ import logging
6
+ import configparser
7
+ import json
8
+ from fastapi import FastAPI , Request , HTTPException
9
+ from pydantic import BaseModel
10
+ from concurrent .futures import ThreadPoolExecutor
11
+ from starlette .responses import PlainTextResponse
12
+
13
+ from modelcache import cache
14
+ from modelcache .adapter import adapter
15
+ from modelcache .manager import CacheBase , VectorBase , get_data_manager
16
+ from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
17
+ from modelcache .processor .pre import query_multi_splicing
18
+ from modelcache .processor .pre import insert_multi_splicing
19
+ from modelcache .utils .model_filter import model_blacklist_filter
20
+ from modelcache .embedding import Data2VecAudio
21
+
22
+ #创建一个FastAPI实例
23
+ app = FastAPI ()
24
+
25
+ class RequestData (BaseModel ):
26
+ type : str
27
+ scope : dict = None
28
+ query : str = None
29
+ chat_info : dict = None
30
+ remove_type : str = None
31
+ id_list : list = []
32
+
33
+ data2vec = Data2VecAudio ()
34
+ mysql_config = configparser .ConfigParser ()
35
+ mysql_config .read ('modelcache/config/mysql_config.ini' )
36
+
37
+ milvus_config = configparser .ConfigParser ()
38
+ milvus_config .read ('modelcache/config/milvus_config.ini' )
39
+
40
+ # redis_config = configparser.ConfigParser()
41
+ # redis_config.read('modelcache/config/redis_config.ini')
42
+
43
+ # 初始化datamanager
44
+ data_manager = get_data_manager (
45
+ CacheBase ("mysql" , config = mysql_config ),
46
+ VectorBase ("milvus" , dimension = data2vec .dimension , milvus_config = milvus_config )
47
+ )
48
+
49
+ # # 使用redis初始化datamanager
50
+ # data_manager = get_data_manager(
51
+ # CacheBase("mysql", config=mysql_config),
52
+ # VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
53
+ # )
54
+
55
+ cache .init (
56
+ embedding_func = data2vec .to_embeddings ,
57
+ data_manager = data_manager ,
58
+ similarity_evaluation = SearchDistanceEvaluation (),
59
+ query_pre_embedding_func = query_multi_splicing ,
60
+ insert_pre_embedding_func = insert_multi_splicing ,
61
+ )
62
+
63
+ executor = ThreadPoolExecutor (max_workers = 6 )
64
+
65
+ # 异步保存查询信息
66
+ async def save_query_info (result , model , query , delta_time_log ):
67
+ loop = asyncio .get_running_loop ()
68
+ await loop .run_in_executor (executor , cache .data_manager .save_query_resp , result , model , json .dumps (query , ensure_ascii = False ), delta_time_log )
69
+
70
+
71
+
72
+ @app .get ("/welcome" , response_class = PlainTextResponse )
73
+ async def first_fastapi ():
74
+ return "hello, modelcache!"
75
+
76
+ @app .post ("/modelcache" )
77
+ async def user_backend (request_data : RequestData ):
78
+ # param parsing
79
+ try :
80
+ request_type = request_data .type
81
+ model = None
82
+ if request_data .scope :
83
+ model = request_data .scope .get ('model' , '' ).replace ('-' ,'_' ).replace ('.' , '_' )
84
+ query = request_data .query
85
+ chat_info = request_data .chat_info
86
+
87
+ except Exception as e :
88
+ result = {"errorCode" : 103 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
89
+ return result
90
+
91
+ # model filter
92
+ filter_resp = model_blacklist_filter (model , request_type )
93
+ if isinstance (filter_resp , dict ):
94
+ return filter_resp
95
+
96
+ if request_type == 'query' :
97
+ try :
98
+ start_time = time .time ()
99
+ response = adapter .ChatCompletion .create_query (scope = {"model" : model }, query = query )
100
+ delta_time = f"{ round (time .time () - start_time , 2 )} s"
101
+
102
+ if response is None :
103
+ result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
104
+ # elif response in ['adapt_query_exception']:
105
+ elif isinstance (response , str ):
106
+ result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
107
+ "hit_query" : '' , "answer" : '' }
108
+ else :
109
+ answer = response ['data' ]
110
+ hit_query = response ['hitQuery' ]
111
+ result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
112
+
113
+ delta_time_log = round (time .time () - start_time , 2 )
114
+ asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
115
+ return result
116
+ except Exception as e :
117
+ result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
118
+ "hit_query" : '' , "answer" : '' }
119
+ logging .info (f'result: { str (result )} ' )
120
+ return result
121
+
122
+ if request_type == 'insert' :
123
+ try :
124
+ response = adapter .ChatCompletion .create_insert (model = model , chat_info = chat_info )
125
+ if response == 'success' :
126
+ return {"errorCode" : 0 , "errorDesc" : "" , "writeStatus" : "success" }
127
+ else :
128
+ return {"errorCode" : 301 , "errorDesc" : response , "writeStatus" : "exception" }
129
+ except Exception as e :
130
+ return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
131
+
132
+ if request_type == 'remove' :
133
+ response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .remove_type , id_list = request_data .id_list )
134
+ if not isinstance (response , dict ):
135
+ return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
136
+
137
+ state = response .get ('status' )
138
+ if state == 'success' :
139
+ return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
140
+ else :
141
+ return {"errorCode" : 402 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
142
+
143
+ if request_type == 'register' :
144
+ response = adapter .ChatCompletion .create_register (model = model )
145
+ if response in ['create_success' , 'already_exists' ]:
146
+ return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
147
+ else :
148
+ return {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
149
+
150
+ # TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
151
+ if __name__ == '__main__' :
152
+ uvicorn .run (app , host = '0.0.0.0' , port = 5001 )
0 commit comments