9
9
from pydantic import BaseModel
10
10
from concurrent .futures import ThreadPoolExecutor
11
11
from starlette .responses import PlainTextResponse
12
+ import functools
12
13
13
14
from modelcache import cache
14
15
from modelcache .adapter import adapter
@@ -63,9 +64,10 @@ class RequestData(BaseModel):
63
64
executor = ThreadPoolExecutor (max_workers = 6 )
64
65
65
66
# 异步保存查询信息
66
- async def save_query_info (result , model , query , delta_time_log ):
67
+ async def save_query_info_fastapi (result , model , query , delta_time_log ):
67
68
loop = asyncio .get_running_loop ()
68
- await loop .run_in_executor (executor , cache .data_manager .save_query_resp , result , model , json .dumps (query , ensure_ascii = False ), delta_time_log )
69
+ func = functools .partial (cache .data_manager .save_query_resp , result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log )
70
+ await loop .run_in_executor (None , func )
69
71
70
72
71
73
@@ -74,20 +76,53 @@ async def first_fastapi():
74
76
return "hello, modelcache!"
75
77
76
78
@app .post ("/modelcache" )
77
- async def user_backend (request_data : RequestData ):
78
- # param parsing
79
+ async def user_backend (request : Request ):
79
80
try :
80
- request_type = request_data .type
81
+ raw_body = await request .body ()
82
+ # 解析字符串为JSON对象
83
+ if isinstance (raw_body , bytes ):
84
+ raw_body = raw_body .decode ("utf-8" )
85
+ if isinstance (raw_body , str ):
86
+ try :
87
+ # 尝试将字符串解析为JSON对象
88
+ request_data = json .loads (raw_body )
89
+ except json .JSONDecodeError :
90
+ # 如果无法解析,返回格式错误
91
+ raise HTTPException (status_code = 400 , detail = "Invalid JSON format" )
92
+ else :
93
+ request_data = raw_body
94
+
95
+ # 确保request_data是字典对象
96
+ if isinstance (request_data , str ):
97
+ try :
98
+ request_data = json .loads (request_data )
99
+ except json .JSONDecodeError :
100
+ raise HTTPException (status_code = 400 , detail = "Invalid JSON format" )
101
+
102
+ request_type = request_data .get ('type' )
81
103
model = None
82
- if request_data .scope :
83
- model = request_data .scope .get ('model' , '' ).replace ('-' ,'_' ).replace ('.' , '_' )
84
- query = request_data .query
85
- chat_info = request_data .chat_info
104
+ if 'scope' in request_data :
105
+ model = request_data ['scope' ].get ('model' , '' ).replace ('-' , '_' ).replace ('.' , '_' )
106
+ query = request_data .get ('query' )
107
+ chat_info = request_data .get ('chat_info' )
108
+
109
+ if not request_type or request_type not in ['query' , 'insert' , 'remove' , 'detox' ]:
110
+ raise HTTPException (status_code = 400 , detail = "Type exception, should be one of ['query', 'insert', 'remove', 'detox']" )
86
111
87
112
except Exception as e :
88
- result = {"errorCode" : 103 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
113
+ request_data = raw_body if 'raw_body' in locals () else None
114
+ result = {
115
+ "errorCode" : 103 ,
116
+ "errorDesc" : str (e ),
117
+ "cacheHit" : False ,
118
+ "delta_time" : 0 ,
119
+ "hit_query" : '' ,
120
+ "answer" : '' ,
121
+ "para_dict" : request_data
122
+ }
89
123
return result
90
124
125
+
91
126
# model filter
92
127
filter_resp = model_blacklist_filter (model , request_type )
93
128
if isinstance (filter_resp , dict ):
@@ -101,8 +136,7 @@ async def user_backend(request_data: RequestData):
101
136
102
137
if response is None :
103
138
result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
104
- # elif response in ['adapt_query_exception']:
105
- elif isinstance (response , str ):
139
+ elif response in ['adapt_query_exception' ]:
106
140
result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
107
141
"hit_query" : '' , "answer" : '' }
108
142
else :
@@ -111,7 +145,7 @@ async def user_backend(request_data: RequestData):
111
145
result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
112
146
113
147
delta_time_log = round (time .time () - start_time , 2 )
114
- asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
148
+ asyncio .create_task (save_query_info_fastapi (result , model , query , delta_time_log ))
115
149
return result
116
150
except Exception as e :
117
151
result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
@@ -130,7 +164,7 @@ async def user_backend(request_data: RequestData):
130
164
return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
131
165
132
166
if request_type == 'remove' :
133
- response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .remove_type , id_list = request_data .id_list )
167
+ response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .get ( " remove_type" ) , id_list = request_data .get ( " id_list" ) )
134
168
if not isinstance (response , dict ):
135
169
return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
136
170
0 commit comments