
Commit d29299d

Merge pull request #58 from charleschile/fastapi-support
FastAPI interface capability
2 parents 8940b74 + c000a42 commit d29299d

File tree

7 files changed: +361 −2 lines changed


examples/flask/llms_cache/data_insert.py (+1 line)

@@ -14,6 +14,7 @@ def run():
     res = requests.post(url, headers=headers, json=json.dumps(data))
     res_text = res.text
 
+    print("data_insert:", res.status_code, res_text)
 
 if __name__ == '__main__':
     run()

examples/flask/llms_cache/data_query.py (+1 line)

@@ -14,6 +14,7 @@ def run():
     res = requests.post(url, headers=headers, json=json.dumps(data))
     res_text = res.text
 
+    print("data_query:", res.status_code, res_text)
 
 if __name__ == '__main__':
     run()

examples/flask/llms_cache/data_query_long.py (+1 line)

@@ -19,6 +19,7 @@ def run():
     res = requests.post(url, headers=headers, json=json.dumps(data))
     res_text = res.text
 
+    print("data_query_long:", res.status_code, res_text)
 
 if __name__ == '__main__':
     run()

fastapi4modelcache.py (new file, +193 lines)

# -*- coding: utf-8 -*-
import time
import uvicorn
import asyncio
import logging
import configparser
import json
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from concurrent.futures import ThreadPoolExecutor
from starlette.responses import PlainTextResponse
import functools

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio

# Create a FastAPI instance
app = FastAPI()

class RequestData(BaseModel):
    type: str
    scope: dict = None
    query: str = None
    chat_info: dict = None
    remove_type: str = None
    id_list: list = []

data2vec = Data2VecAudio()
mysql_config = configparser.ConfigParser()
mysql_config.read('modelcache/config/mysql_config.ini')

milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')

# redis_config = configparser.ConfigParser()
# redis_config.read('modelcache/config/redis_config.ini')

# Initialize the data manager
data_manager = get_data_manager(
    CacheBase("mysql", config=mysql_config),
    VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)
)

# # Initialize the data manager with Redis instead
# data_manager = get_data_manager(
#     CacheBase("mysql", config=mysql_config),
#     VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
# )

cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

executor = ThreadPoolExecutor(max_workers=6)

# Save query info asynchronously
async def save_query_info(result, model, query, delta_time_log):
    loop = asyncio.get_running_loop()
    func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
    await loop.run_in_executor(None, func)


@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    return "hello, modelcache!"

@app.post("/modelcache")
async def user_backend(request: Request):
    try:
        raw_body = await request.body()
        # Parse the request body into a JSON object
        if isinstance(raw_body, bytes):
            raw_body = raw_body.decode("utf-8")
        if isinstance(raw_body, str):
            try:
                # Try to parse the string as JSON
                request_data = json.loads(raw_body)
            except json.JSONDecodeError as e:
                # If parsing fails, report a format error
                result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                          "answer": ''}
                asyncio.create_task(save_query_info(result, model='', query='', delta_time_log=0))
                raise HTTPException(status_code=101, detail="Invalid JSON format")
        else:
            request_data = raw_body

        # Make sure request_data is a dict
        if isinstance(request_data, str):
            try:
                request_data = json.loads(request_data)
            except json.JSONDecodeError:
                raise HTTPException(status_code=101, detail="Invalid JSON format")

        request_type = request_data.get('type')
        model = None
        if 'scope' in request_data:
            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.get('query')
        chat_info = request_data.get('chat_info')

        if not request_type or request_type not in ['query', 'insert', 'remove', 'register']:
            result = {"errorCode": 102,
                      "errorDesc": "type exception, should be one of ['query', 'insert', 'remove', 'register']",
                      "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
            asyncio.create_task(save_query_info(result, model=model, query='', delta_time_log=0))
            raise HTTPException(status_code=102, detail="Type exception, should be one of ['query', 'insert', 'remove', 'register']")

    except Exception as e:
        request_data = raw_body if 'raw_body' in locals() else None
        result = {
            "errorCode": 103,
            "errorDesc": str(e),
            "cacheHit": False,
            "delta_time": 0,
            "hit_query": '',
            "answer": '',
            "para_dict": request_data
        }
        return result


    # model filter
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}

            delta_time_log = round(time.time() - start_time, 2)
            asyncio.create_task(save_query_info(result, model, query, delta_time_log))
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info(f'result: {str(result)}')
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    if request_type == 'remove':
        response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
        if not isinstance(response, dict):
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

    if request_type == 'register':
        response = adapter.ChatCompletion.create_register(model=model)
        if response in ['create_success', 'already_exists']:
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}

# TODO: could instead be launched from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=5000)
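
For reference, a minimal client sketch against the new /modelcache endpoint. It is not part of this commit; it assumes the service above is running locally on port 5000, and the model name and chat_info structure are illustrative placeholders modelled on the repo's flask examples.

# Hypothetical client for the /modelcache endpoint (not part of this commit).
import requests

url = "http://127.0.0.1:5000/modelcache"
headers = {"Content-Type": "application/json"}

# Write a question/answer pair into the cache ("insert").
insert_payload = {
    "type": "insert",
    "scope": {"model": "example_model"},  # placeholder model name
    "chat_info": [{
        "query": [{"role": "user", "content": "What is ModelCache?"}],  # assumed payload shape
        "answer": "ModelCache is a semantic cache for LLM responses.",
    }],
}
res = requests.post(url, headers=headers, json=insert_payload)
print("insert:", res.status_code, res.text)

# Ask the same question again; a cache hit returns cacheHit=True plus the stored answer.
query_payload = {
    "type": "query",
    "scope": {"model": "example_model"},
    "query": [{"role": "user", "content": "What is ModelCache?"}],
}
res = requests.post(url, headers=headers, json=query_payload)
print("query:", res.status_code, res.text)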

fastapi4modelcache_demo.py (new file, +162 lines)

# -*- coding: utf-8 -*-
import time
import uvicorn
import asyncio
import logging
# import configparser
import json
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from concurrent.futures import ThreadPoolExecutor
from starlette.responses import PlainTextResponse
import functools

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio

# Create a FastAPI instance
app = FastAPI()

class RequestData(BaseModel):
    type: str
    scope: dict = None
    query: str = None
    chat_info: list = None
    remove_type: str = None
    id_list: list = []

data2vec = Data2VecAudio()

data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))

cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

executor = ThreadPoolExecutor(max_workers=6)

# Save query info asynchronously
async def save_query_info_fastapi(result, model, query, delta_time_log):
    loop = asyncio.get_running_loop()
    func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
    await loop.run_in_executor(None, func)


@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    return "hello, modelcache!"

@app.post("/modelcache")
async def user_backend(request: Request):
    try:
        raw_body = await request.body()
        # Parse the request body into a JSON object
        if isinstance(raw_body, bytes):
            raw_body = raw_body.decode("utf-8")
        if isinstance(raw_body, str):
            try:
                # Try to parse the string as JSON
                request_data = json.loads(raw_body)
            except json.JSONDecodeError:
                # If parsing fails, report a format error
                raise HTTPException(status_code=400, detail="Invalid JSON format")
        else:
            request_data = raw_body

        # Make sure request_data is a dict
        if isinstance(request_data, str):
            try:
                request_data = json.loads(request_data)
            except json.JSONDecodeError:
                raise HTTPException(status_code=400, detail="Invalid JSON format")

        request_type = request_data.get('type')
        model = None
        if 'scope' in request_data:
            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.get('query')
        chat_info = request_data.get('chat_info')

        if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
            raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")

    except Exception as e:
        request_data = raw_body if 'raw_body' in locals() else None
        result = {
            "errorCode": 103,
            "errorDesc": str(e),
            "cacheHit": False,
            "delta_time": 0,
            "hit_query": '',
            "answer": '',
            "para_dict": request_data
        }
        return result


    # model filter
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
            # elif isinstance(response, str):
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}

            delta_time_log = round(time.time() - start_time, 2)
            asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info(f'result: {str(result)}')
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    if request_type == 'remove':
        response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
        if not isinstance(response, dict):
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

# TODO: could instead be launched from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=5000)
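
The TODO above suggests starting the service through the uvicorn CLI instead of uvicorn.run(app, ...). A programmatic sketch that supports auto-reload is shown below; it is not part of this commit, passing the app as an import string is required for reload to work, and the module name is an assumption.

# Hypothetical alternative entry point (assumes this file is named fastapi4modelcache_demo.py).
if __name__ == '__main__':
    uvicorn.run("fastapi4modelcache_demo:app", host="0.0.0.0", port=5000, reload=True)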

modelcache/manager/scalar_data/sql_storage_sqlite.py (+1 −2 lines)

@@ -100,8 +100,7 @@ def insert_query_resp(self, query_resp, **kwargs):
         hit_query = json.dumps(hit_query, ensure_ascii=False)
 
         table_name = "modelcache_query_log"
-        insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
-
+        insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name)
         conn = sqlite3.connect(self._url)
         try:
             cursor = conn.cursor()
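
The placeholder swap matters because Python's built-in sqlite3 driver uses qmark-style "?" parameters, while the "%s" style belongs to MySQL drivers; with "%s" the statement fails at execute time. A minimal standalone sketch with an illustrative table and values:

# sqlite3 binds parameters with "?", not "%s" (which raises sqlite3.OperationalError).
import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE modelcache_query_log (error_code INTEGER, query TEXT)")

insert_sql = "INSERT INTO modelcache_query_log (error_code, query) VALUES (?, ?)"
cursor.execute(insert_sql, (0, "hello"))  # values are bound safely by the driver
conn.commit()
conn.close()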

requirements.txt (+2 lines)

@@ -12,3 +12,5 @@ transformers==4.38.2
 faiss-cpu==1.7.4
 redis==5.0.1
 modelscope==1.14.0
+fastapi==0.115.5
+uvicorn==0.32.0
