Skip to content

Commit 31da4b7

Browse files
authored
Merge pull request #46 from codefuse-ai/modelcache_dev_mm
Update README; update multimodal (mm) cache storage logic
2 parents fedf74d + 021a88b commit 31da4b7

File tree

5 files changed

+122
-81
lines changed

5 files changed

+122
-81
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<div align="center">
22
<h1>
3-
Codefuse-ModelCache
3+
ModelCache
44
</h1>
55
</div>
66

@@ -262,7 +262,7 @@ In ModelCache, we adopted the main idea of GPTCache, includes core modules: ada
262262
- [ ] Support ElasticSearch
263263
### Vector Storage
264264
- [ ] Adapts Faiss storage in multimodal scenarios.
265-
### Rank能力
265+
### Ranking
266266
- [ ] Add ranking model to refine the order of data after embedding recall.
267267
### Service
268268
- [ ] Supports FastAPI.

README_CN.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<div align="center">
22
<h1>
3-
Codefuse-ModelCache
3+
ModelCache
44
</h1>
55
</div>
66

modelcache_mm/manager/scalar_data/sql_storage_sqlite.py

+38-51
Original file line numberDiff line numberDiff line change
@@ -16,47 +16,34 @@ def __init__(
1616
self.create()
1717

1818
def create(self):
19-
# answer_table_sql = """CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` (
20-
# `id` bigint(20) NOT NULL AUTO_INCREMENT comment '主键',
21-
# `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',
22-
# `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间',
23-
# `question` text NOT NULL comment 'question',
24-
# `answer` text NOT NULL comment 'answer',
25-
# `answer_type` int(11) NOT NULL comment 'answer_type',
26-
# `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count',
27-
# `model` varchar(1000) NOT NULL comment 'model',
28-
# `embedding_data` blob NOT NULL comment 'embedding_data',
29-
# PRIMARY KEY(`id`)
30-
# ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_llm_answer';
31-
# """
32-
answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer (
33-
id INTEGER PRIMARY KEY AUTOINCREMENT,
34-
gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
35-
gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
36-
question TEXT NOT NULL,
37-
answer TEXT NOT NULL,
38-
answer_type INTEGER NOT NULL,
39-
hit_count INTEGER NOT NULL DEFAULT 0,
40-
model VARCHAR(1000) NOT NULL,
41-
embedding_data BLOB NOT NULL
42-
);
43-
"""
19+
# answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer (
20+
# id INTEGER PRIMARY KEY AUTOINCREMENT,
21+
# gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
22+
# gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
23+
# question TEXT NOT NULL,
24+
# answer TEXT NOT NULL,
25+
# answer_type INTEGER NOT NULL,
26+
# hit_count INTEGER NOT NULL DEFAULT 0,
27+
# model VARCHAR(1000) NOT NULL,
28+
# embedding_data BLOB NOT NULL
29+
# );
30+
# """
31+
32+
answer_table_sql = """CREATE TABLE IF NOT EXISTS `open_cache_mm_answer` (
33+
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
34+
`gmt_create` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
35+
`gmt_modified` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
36+
`question_text` TEXT NOT NULL,
37+
`image_url` VARCHAR(2048) NOT NULL,
38+
`answer` TEXT NOT NULL,
39+
`answer_type` INTEGER NOT NULL,
40+
`hit_count` INTEGER NOT NULL DEFAULT 0,
41+
`model` VARCHAR(1000) NOT NULL,
42+
`image_raw` BLOB DEFAULT NULL,
43+
`image_id` VARCHAR(1000) DEFAULT NULL
44+
);
45+
"""
4446

45-
# log_table_sql = """CREATE TABLE IF NOT EXISTS `modelcache_query_log` (
46-
# `id` bigint(20) NOT NULL AUTO_INCREMENT comment '主键',
47-
# `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',
48-
# `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间',
49-
# `error_code` int(11) NOT NULL comment 'errorCode',
50-
# `error_desc` varchar(1000) NOT NULL comment 'errorDesc',
51-
# `cache_hit` varchar(100) NOT NULL comment 'cacheHit',
52-
# `delta_time` float NOT NULL comment 'delta_time',
53-
# `model` varchar(1000) NOT NULL comment 'model',
54-
# `query` text NOT NULL comment 'query',
55-
# `hit_query` text NOT NULL comment 'hitQuery',
56-
# `answer` text NOT NULL comment 'answer',
57-
# PRIMARY KEY(`id`)
58-
# ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log';
59-
# """
6047
log_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_query_log (
6148
id INTEGER PRIMARY KEY AUTOINCREMENT,
6249
gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -85,19 +72,19 @@ def create(self):
8572

8673
def _insert(self, data: List):
8774
answer = data[0]
88-
question = data[1]
89-
embedding_data = data[2]
90-
model = data[3]
75+
text = data[1]
76+
image_url = data[2]
77+
image_id = data[3]
78+
model = data[4]
9179
answer_type = 0
92-
embedding_data = embedding_data.tobytes()
9380

94-
table_name = "modelcache_llm_answer"
95-
insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data) VALUES (?, ?, ?, ?, ?)".format(table_name)
81+
table_name = "open_cache_mm_answer"
82+
insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (?, ?, ?, ?, ?, ?)".format(table_name)
9683

9784
conn = sqlite3.connect(self._url)
9885
try:
9986
cursor = conn.cursor()
100-
values = (question, answer, answer_type, model, embedding_data)
87+
values = (text, image_url, image_id, answer, answer_type, model)
10188
cursor.execute(insert_sql, values)
10289
conn.commit()
10390
id = cursor.lastrowid
@@ -141,7 +128,7 @@ def insert_query_resp(self, query_resp, **kwargs):
141128
conn.close()
142129

143130
def get_data_by_id(self, key: int):
144-
table_name = "modelcache_llm_answer"
131+
table_name = "open_cache_mm_answer"
145132
query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key)
146133
conn = sqlite3.connect(self._url)
147134
try:
@@ -160,7 +147,7 @@ def get_data_by_id(self, key: int):
160147
return None
161148

162149
def update_hit_count_by_id(self, primary_id: int):
163-
table_name = "modelcache_llm_answer"
150+
table_name = "open_cache_mm_answer"
164151
update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id)
165152

166153
conn = sqlite3.connect(self._url)
@@ -178,7 +165,7 @@ def get_ids(self, deleted=True):
178165
pass
179166

180167
def mark_deleted(self, keys):
181-
table_name = "modelcache_llm_answer"
168+
table_name = "open_cache_mm_answer"
182169
delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys]))
183170
conn = sqlite3.connect(self._url)
184171
try:
@@ -193,7 +180,7 @@ def mark_deleted(self, keys):
193180
return delete_count
194181

195182
def model_deleted(self, model_name):
196-
table_name = "modelcache_llm_answer"
183+
table_name = "open_cache_mm_answer"
197184
delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name)
198185
conn = sqlite3.connect(self._url)
199186
try:
+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
from typing import List
4+
import numpy as np
5+
from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
6+
from modelcache_mm.utils import import_faiss
7+
import_faiss()
8+
import faiss # pylint: disable=C0413
9+
10+
11+
class Faiss(VectorBase):
    """Vector store backed by a local Faiss index file.

    Uses an ``IDMap,Flat`` index: exact (brute-force) L2 search over raw
    float32 vectors, addressable by caller-supplied integer ids.  The index
    is persisted to ``index_file_path`` by :meth:`flush` / :meth:`close`.
    """

    def __init__(self,
                 index_file_path,
                 dimension: int = 0,
                 top_k: int = 1
                 ):
        """
        Args:
            index_file_path: on-disk index path; an existing file is loaded.
            dimension: vector dimensionality (must be > 0 for a fresh index;
                the default of 0 only works when the index is loaded from disk).
            top_k: default neighbour count used by :meth:`search` for top_k=-1.
        """
        self._dimension = dimension
        self._index_file_path = index_file_path
        # "IDMap,Flat" = exact L2 search with external integer ids.
        self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2)
        self._top_k = top_k
        # A previously persisted index takes precedence over the fresh one.
        if os.path.isfile(index_file_path):
            self._index = faiss.read_index(index_file_path)

    def add(self, datas: List[VectorData], model=None, mm_type=None):
        """Insert vectors with their ids (``model``/``mm_type`` are unused here)."""
        if not datas:
            # zip(*()) would raise ValueError on the unpack below; an empty
            # batch is simply a no-op.
            return
        data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
        np_data = np.array(data_array).astype("float32")
        ids = np.array(id_array)
        self._index.add_with_ids(np_data, ids)

    def search(self, data: np.ndarray, top_k: int, model, mm_type='mm'):
        """Return ``[(distance, id), ...]`` for the nearest vectors.

        Returns None when the index is empty.  ``top_k == -1`` requests the
        default ``top_k`` given at construction time.
        """
        if self._index.ntotal == 0:
            return None
        if top_k == -1:
            top_k = self._top_k
        np_data = np.array(data).astype("float32").reshape(1, -1)
        dist, ids = self._index.search(np_data, top_k)
        ids = [int(i) for i in ids[0]]
        return list(zip(dist[0], ids))

    def rebuild_col(self, ids=None):
        """Drop all vectors from the index.

        NOTE(review): returns an error *string* on failure and None on
        success — callers must not test the result for truthiness.
        """
        try:
            self._index.reset()
        except Exception as e:
            return f"An error occurred during index rebuild: {e}"

    def rebuild(self, ids=None):
        # Flat indexes need no retraining; always report success.
        return True

    def delete(self, ids):
        """Remove the given integer ids from the index."""
        ids_to_remove = np.array(ids)
        self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove)))

    def create(self, model=None, mm_type=None):
        # No-op: a flat index needs no per-model collection setup.
        pass

    def flush(self):
        """Persist the current index to ``index_file_path``."""
        faiss.write_index(self._index, self._index_file_path)

    def close(self):
        # Closing just persists; faiss indexes hold no OS resources to release.
        self.flush()

    def rebuild_idx(self, model):
        # Interface stub — per-model rebuild is not applicable to this backend.
        pass

    def count(self):
        """Number of vectors currently stored."""
        return self._index.ntotal

modelcache_mm/manager/vector_data/manager.py

+6-27
Original file line numberDiff line numberDiff line change
@@ -98,36 +98,15 @@ def get(name, **kwargs):
9898
t_dimension=t_dimension,
9999
)
100100
elif name == "faiss":
101-
from modelcache.manager.vector_data.faiss import Faiss
102-
101+
from modelcache_mm.manager.vector_data.faiss import Faiss
103102
dimension = kwargs.get("dimension", DIMENSION)
104-
index_path = kwargs.pop("index_path", FAISS_INDEX_PATH)
105103
VectorBase.check_dimension(dimension)
106-
vector_base = Faiss(
107-
index_file_path=index_path, dimension=dimension, top_k=top_k
108-
)
109-
elif name == "chromadb":
110-
from modelcache.manager.vector_data.chroma import Chromadb
111-
112-
client_settings = kwargs.get("client_settings", None)
113-
persist_directory = kwargs.get("persist_directory", None)
114-
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
115-
vector_base = Chromadb(
116-
client_settings=client_settings,
117-
persist_directory=persist_directory,
118-
collection_name=collection_name,
119-
top_k=top_k,
120-
)
121-
elif name == "hnswlib":
122-
from modelcache.manager.vector_data.hnswlib_store import Hnswlib
123104

124-
dimension = kwargs.get("dimension", DIMENSION)
125-
index_path = kwargs.pop("index_path", "./hnswlib_index.bin")
126-
max_elements = kwargs.pop("max_elements", 100000)
127-
VectorBase.check_dimension(dimension)
128-
vector_base = Hnswlib(
129-
index_file_path=index_path, dimension=dimension,
130-
top_k=top_k, max_elements=max_elements
105+
index_path = kwargs.pop("index_path", FAISS_INDEX_PATH)
106+
vector_base = Faiss(
107+
index_file_path=index_path,
108+
dimension=dimension,
109+
top_k=top_k
131110
)
132111
else:
133112
raise NotFoundError("vector store", name)

0 commit comments

Comments
 (0)