Skip to content

Commit 642e1ca

Browse files
committed
add fastapi4modelcache
1 parent fc003de commit 642e1ca

File tree

1 file changed

+150
-1
lines changed

1 file changed

+150
-1
lines changed

fastapi4modelcache.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,152 @@
1+
# -*- coding: utf-8 -*-
2+
import time
13
import uvicorn
24
import asyncio
3-
from fastapi import FastAPI, Request
5+
import logging
6+
import configparser
7+
import json
8+
from fastapi import FastAPI, Request, HTTPException
9+
from pydantic import BaseModel
10+
from concurrent.futures import ThreadPoolExecutor
11+
from starlette.responses import PlainTextResponse
12+
13+
from modelcache import cache
14+
from modelcache.adapter import adapter
15+
from modelcache.manager import CacheBase, VectorBase, get_data_manager
16+
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
17+
from modelcache.processor.pre import query_multi_splicing
18+
from modelcache.processor.pre import insert_multi_splicing
19+
from modelcache.utils.model_filter import model_blacklist_filter
20+
from modelcache.embedding import Data2VecAudio
21+
22+
# Create the FastAPI application instance that the route decorators attach to.
app = FastAPI()
24+
25+
class RequestData(BaseModel):
    """Request body accepted by the ``/modelcache`` endpoint.

    Only ``type`` is required; every other field falls back to the default
    declared here when the client leaves it out.
    """

    # Operation selector. The endpoint handles 'query', 'insert', 'remove'
    # and 'register'.
    type: str

    # Mapping from which the endpoint reads the 'model' key.
    # NOTE(review): the default is None although the annotation is not
    # Optional; how an explicit JSON null is validated depends on the
    # installed pydantic major version -- confirm.
    scope: dict = None

    # Payload handed to the cache lookup for 'query' requests.
    # NOTE(review): annotated as str, yet the endpoint JSON-serialises it
    # before logging; confirm that clients really send a plain string.
    query: str = None

    # Payload handed to the cache insert call for 'insert' requests.
    chat_info: dict = None

    # Both of these are forwarded unchanged to the remove call.
    remove_type: str = None
    id_list: list = []
32+
33+
def _read_ini(path):
    """Return a ConfigParser populated from the INI file at *path*."""
    parser = configparser.ConfigParser()
    parser.read(path)
    return parser


# Embedding model; its ``dimension`` attribute sizes the vector store below.
data2vec = Data2VecAudio()

mysql_config = _read_ini('modelcache/config/mysql_config.ini')
milvus_config = _read_ini('modelcache/config/milvus_config.ini')

# Disabled alternative: load the Redis configuration instead.
# redis_config = _read_ini('modelcache/config/redis_config.ini')

# Initialise the data manager from a MySQL-backed cache store and a
# Milvus-backed vector store.
_cache_store = CacheBase("mysql", config=mysql_config)
_vector_store = VectorBase(
    "milvus",
    dimension=data2vec.dimension,
    milvus_config=milvus_config,
)
data_manager = get_data_manager(_cache_store, _vector_store)

# Disabled alternative: initialise the data manager with Redis as the
# vector store.
# data_manager = get_data_manager(
#     CacheBase("mysql", config=mysql_config),
#     VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
# )

# Wire the embedding function, storage and pre-processing hooks into the
# shared cache object.
cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# Thread pool used to run blocking work off the event loop.
executor = ThreadPoolExecutor(max_workers=6)
64+
65+
# Persist query information asynchronously.
async def save_query_info(result, model, query, delta_time_log):
    """Record one query and its response via the cache's data manager.

    ``query`` is serialised to a JSON string (non-ASCII characters are kept
    as-is) and the save call runs on the module-level thread pool, so the
    coroutine only awaits its completion.
    """
    event_loop = asyncio.get_running_loop()
    persist = cache.data_manager.save_query_resp
    query_json = json.dumps(query, ensure_ascii=False)
    await event_loop.run_in_executor(
        executor,
        persist,
        result,
        model,
        query_json,
        delta_time_log,
    )
69+
70+
71+
72+
@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    """Return a fixed plain-text greeting."""
    greeting = "hello, modelcache!"
    return greeting
75+
76+
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a fire-and-forget task can be garbage-collected
# before it completes unless something else holds on to it.
_background_tasks = set()


@app.post("/modelcache")
async def user_backend(request_data: RequestData):
    """Dispatch a cache request according to ``request_data.type``.

    Handled types are 'query', 'insert', 'remove' and 'register'. Every
    outcome is reported as a dict carrying an ``errorCode`` (0 on success)
    and an ``errorDesc``; the remaining keys depend on the request type.

    Args:
        request_data: Parsed request body.

    Returns:
        The dict returned by the model blacklist filter when it returns one,
        otherwise the result dict of the matching branch. ``None`` when the
        type matches none of the handled values.
    """
    # Parameter parsing.
    try:
        request_type = request_data.type
        model = None
        if request_data.scope:
            # Normalise the model name: '-' and '.' both become '_'.
            model = request_data.scope.get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.query
        chat_info = request_data.chat_info
    except Exception as e:
        result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
        return result

    # Model filter: a dict response is sent back to the caller unchanged.
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                # No cached entry was returned.
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
            elif isinstance(response, str):
                # A string response is reported back as the error description.
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}

            delta_time_log = round(time.time() - start_time, 2)
            # Log the query in the background without delaying the response.
            # Keep a reference until the task finishes so it cannot be
            # garbage-collected mid-execution.
            save_task = asyncio.create_task(save_query_info(result, model, query, delta_time_log))
            _background_tasks.add(save_task)
            save_task.add_done_callback(_background_tasks.discard)
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info(f'result: {str(result)}')
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    # NOTE(review): unlike 'query' and 'insert', the 'remove' and 'register'
    # branches do not catch exceptions raised by the adapter -- confirm that
    # this is intended.
    if request_type == 'remove':
        response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.remove_type, id_list=request_data.id_list)
        if not isinstance(response, dict):
            # NOTE(review): this branch reports "removeStatus" while the
            # other outcomes use "writeStatus"; kept as-is for clients.
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

    if request_type == 'register':
        response = adapter.ChatCompletion.create_register(model=model)
        if response in ['create_success', 'already_exists']:
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}

    # NOTE(review): an unrecognised request type reaches this point and yields
    # None, unless the blacklist filter already rejected it -- confirm whether
    # an explicit error response is wanted here.
    return None
149+
150+
# TODO: this could instead be started from the command line with
# `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`.
if __name__ == '__main__':
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=5001,
    )

0 commit comments

Comments
 (0)