Skip to content

Commit 3dbc951

Browse files
committed
fastapi for modelcache
1 parent 72018ce commit 3dbc951

File tree

4 files changed

+51
-14
lines changed

4 files changed

+51
-14
lines changed

examples/flask/llms_cache/data_insert.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def run():
1414
res = requests.post(url, headers=headers, json=json.dumps(data))
1515
res_text = res.text
1616

17+
print("data_insert:", res.status_code, res_text)
1718

1819
if __name__ == '__main__':
1920
run()

examples/flask/llms_cache/data_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def run():
1414
res = requests.post(url, headers=headers, json=json.dumps(data))
1515
res_text = res.text
1616

17+
print("data_query:", res.status_code, res_text)
1718

1819
if __name__ == '__main__':
1920
run()

examples/flask/llms_cache/data_query_long.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def run():
1919
res = requests.post(url, headers=headers, json=json.dumps(data))
2020
res_text = res.text
2121

22+
print("data_query_long:", res.status_code, res_text)
2223

2324
if __name__ == '__main__':
2425
run()

fastapi4modelcache.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import BaseModel
1010
from concurrent.futures import ThreadPoolExecutor
1111
from starlette.responses import PlainTextResponse
12+
import functools
1213

1314
from modelcache import cache
1415
from modelcache.adapter import adapter
@@ -63,9 +64,10 @@ class RequestData(BaseModel):
6364
executor = ThreadPoolExecutor(max_workers=6)
6465

6566
# 异步保存查询信息
66-
async def save_query_info(result, model, query, delta_time_log):
67+
async def save_query_info_fastapi(result, model, query, delta_time_log):
6768
loop = asyncio.get_running_loop()
68-
await loop.run_in_executor(executor, cache.data_manager.save_query_resp, result, model, json.dumps(query, ensure_ascii=False), delta_time_log)
69+
func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
70+
await loop.run_in_executor(None, func)
6971

7072

7173

@@ -74,20 +76,53 @@ async def first_fastapi():
7476
return "hello, modelcache!"
7577

7678
@app.post("/modelcache")
77-
async def user_backend(request_data: RequestData):
78-
# param parsing
79+
async def user_backend(request: Request):
7980
try:
80-
request_type = request_data.type
81+
raw_body = await request.body()
82+
# 解析字符串为JSON对象
83+
if isinstance(raw_body, bytes):
84+
raw_body = raw_body.decode("utf-8")
85+
if isinstance(raw_body, str):
86+
try:
87+
# 尝试将字符串解析为JSON对象
88+
request_data = json.loads(raw_body)
89+
except json.JSONDecodeError:
90+
# 如果无法解析,返回格式错误
91+
raise HTTPException(status_code=400, detail="Invalid JSON format")
92+
else:
93+
request_data = raw_body
94+
95+
# 确保request_data是字典对象
96+
if isinstance(request_data, str):
97+
try:
98+
request_data = json.loads(request_data)
99+
except json.JSONDecodeError:
100+
raise HTTPException(status_code=400, detail="Invalid JSON format")
101+
102+
request_type = request_data.get('type')
81103
model = None
82-
if request_data.scope:
83-
model = request_data.scope.get('model', '').replace('-','_').replace('.', '_')
84-
query = request_data.query
85-
chat_info = request_data.chat_info
104+
if 'scope' in request_data:
105+
model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
106+
query = request_data.get('query')
107+
chat_info = request_data.get('chat_info')
108+
109+
if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
110+
raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")
86111

87112
except Exception as e:
88-
result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
113+
request_data = raw_body if 'raw_body' in locals() else None
114+
result = {
115+
"errorCode": 103,
116+
"errorDesc": str(e),
117+
"cacheHit": False,
118+
"delta_time": 0,
119+
"hit_query": '',
120+
"answer": '',
121+
"para_dict": request_data
122+
}
89123
return result
90124

125+
91126
# model filter
92127
filter_resp = model_blacklist_filter(model, request_type)
93128
if isinstance(filter_resp, dict):
@@ -101,8 +136,7 @@ async def user_backend(request_data: RequestData):
101136

102137
if response is None:
103138
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
104-
# elif response in ['adapt_query_exception']:
105-
elif isinstance(response, str):
139+
elif response in ['adapt_query_exception']:
106140
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
107141
"hit_query": '', "answer": ''}
108142
else:
@@ -111,7 +145,7 @@ async def user_backend(request_data: RequestData):
111145
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
112146

113147
delta_time_log = round(time.time() - start_time, 2)
114-
asyncio.create_task(save_query_info(result, model, query, delta_time_log))
148+
asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
115149
return result
116150
except Exception as e:
117151
result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
@@ -130,7 +164,7 @@ async def user_backend(request_data: RequestData):
130164
return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
131165

132166
if request_type == 'remove':
133-
response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.remove_type, id_list=request_data.id_list)
167+
response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
134168
if not isinstance(response, dict):
135169
return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
136170

0 commit comments

Comments
 (0)