Skip to content

Commit 3dad91c

Browse files
authored
Merge pull request #60 from powerli2002/add-chroma
Add feature : add chromadb support as a vector database
2 parents e603a1c + f3c7657 commit 3dad91c

File tree

10 files changed

+223
-6
lines changed

10 files changed

+223
-6
lines changed

flask4modelcache.py

+6
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,16 @@ def response_hitquery(cache_resp):
4141
# redis_config = configparser.ConfigParser()
4242
# redis_config.read('modelcache/config/redis_config.ini')
4343

44+
# chromadb_config = configparser.ConfigParser()
45+
# chromadb_config.read('modelcache/config/chromadb_config.ini')
4446

4547
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
4648
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
4749

50+
51+
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
52+
# VectorBase("chromadb", dimension=data2vec.dimension, chromadb_config=chromadb_config))
53+
4854
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
4955
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
5056

modelcache/config/chromadb_config.ini

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[chromadb]
2+
persist_directory=''
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import List
2+
3+
import numpy as np
4+
import logging
5+
from modelcache.manager.vector_data.base import VectorBase, VectorData
6+
from modelcache.utils import import_chromadb, import_torch
7+
8+
import_torch()
9+
import_chromadb()
10+
11+
import chromadb
12+
13+
14+
class Chromadb(VectorBase):
15+
16+
def __init__(
17+
self,
18+
persist_directory="./chromadb",
19+
top_k: int = 1,
20+
):
21+
self.collection_name = "modelcache"
22+
self.top_k = top_k
23+
24+
self._client = chromadb.PersistentClient(path=persist_directory)
25+
self._collection = None
26+
27+
def mul_add(self, datas: List[VectorData], model=None):
28+
collection_name_model = self.collection_name + '_' + model
29+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
30+
31+
data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
32+
self._collection.add(embeddings=data_array, ids=id_array)
33+
34+
def search(self, data: np.ndarray, top_k: int = -1, model=None):
35+
collection_name_model = self.collection_name + '_' + model
36+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
37+
38+
if self._collection.count() == 0:
39+
return []
40+
if top_k == -1:
41+
top_k = self.top_k
42+
results = self._collection.query(
43+
query_embeddings=[data.tolist()],
44+
n_results=top_k,
45+
include=["distances"],
46+
)
47+
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
48+
49+
def rebuild(self, ids=None):
50+
pass
51+
52+
def delete(self, ids, model=None):
53+
try:
54+
collection_name_model = self.collection_name + '_' + model
55+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
56+
# 查询集合中实际存在的 ID
57+
ids_str = [str(x) for x in ids]
58+
existing_ids = set(self._collection.get(ids=ids_str).ids)
59+
60+
# 删除存在的 ID
61+
if existing_ids:
62+
self._collection.delete(list(existing_ids))
63+
64+
# 返回实际删除的条目数量
65+
return len(existing_ids)
66+
67+
except Exception as e:
68+
logging.error('Error during deletion: {}'.format(e))
69+
raise ValueError(str(e))
70+
71+
def rebuild_col(self, model):
72+
collection_name_model = self.collection_name + '_' + model
73+
74+
# 检查集合是否存在,如果存在则删除
75+
collections = self._client.list_collections()
76+
if any(col.name == collection_name_model for col in collections):
77+
self._client.delete_collection(collection_name_model)
78+
else:
79+
return 'model collection not found, please check!'
80+
81+
try:
82+
self._client.create_collection(collection_name_model)
83+
except Exception as e:
84+
logging.info(f'rebuild_collection: {e}')
85+
raise ValueError(str(e))
86+
87+
def flush(self):
88+
# chroma无flush方法
89+
pass
90+
91+
def close(self):
92+
pass

modelcache/manager/vector_data/manager.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,11 @@ def get(name, **kwargs):
102102
elif name == "chromadb":
103103
from modelcache.manager.vector_data.chroma import Chromadb
104104

105-
client_settings = kwargs.get("client_settings", None)
106-
persist_directory = kwargs.get("persist_directory", None)
107-
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
105+
chromadb_config = kwargs.get("chromadb_config", None)
106+
persist_directory = chromadb_config.get('chromadb','persist_directory')
107+
108108
vector_base = Chromadb(
109-
client_settings=client_settings,
110109
persist_directory=persist_directory,
111-
collection_name=collection_name,
112110
top_k=top_k,
113111
)
114112
elif name == "hnswlib":

modelcache/utils/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,7 @@ def import_pillow():
7373

7474
def import_redis():
7575
_check_library("redis")
76+
77+
78+
def import_chromadb():
79+
_check_library("chromadb", package="chromadb")
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[chromadb]
2+
persist_directory=./chromadb
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import List
2+
3+
import numpy as np
4+
import logging
5+
from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
6+
from modelcache_mm.utils import import_chromadb, import_torch
7+
from modelcache_mm.utils.index_util import get_mm_index_name
8+
9+
import_torch()
10+
import_chromadb()
11+
12+
import chromadb
13+
14+
15+
class Chromadb(VectorBase):
16+
17+
def __init__(
18+
self,
19+
persist_directory="./chromadb",
20+
top_k: int = 1,
21+
):
22+
self.top_k = top_k
23+
24+
self._client = chromadb.PersistentClient(path=persist_directory)
25+
self._collection = None
26+
27+
def create(self, model=None, mm_type=None):
28+
try:
29+
collection_name_model = get_mm_index_name(model, mm_type)
30+
# collection_name_model = self.collection_name + '_' + model
31+
self._client.get_or_create_collection(name=collection_name_model)
32+
except Exception as e:
33+
raise ValueError(str(e))
34+
35+
def add(self, datas: List[VectorData], model=None, mm_type=None):
36+
collection_name_model = get_mm_index_name(model, mm_type)
37+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
38+
39+
data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
40+
self._collection.add(embeddings=data_array, ids=id_array)
41+
42+
def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'):
43+
collection_name_model = get_mm_index_name(model, mm_type)
44+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
45+
46+
if self._collection.count() == 0:
47+
return []
48+
if top_k == -1:
49+
top_k = self.top_k
50+
results = self._collection.query(
51+
query_embeddings=[data.tolist()],
52+
n_results=top_k,
53+
include=["distances"],
54+
)
55+
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
56+
57+
def delete(self, ids, model=None, mm_type=None):
58+
try:
59+
collection_name_model = get_mm_index_name(model, mm_type)
60+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
61+
# 查询集合中实际存在的 ID
62+
ids_str = [str(x) for x in ids]
63+
existing_ids = set(self._collection.get(ids=ids_str).ids)
64+
65+
# 删除存在的 ID
66+
if existing_ids:
67+
self._collection.delete(list(existing_ids))
68+
69+
# 返回实际删除的条目数量
70+
return len(existing_ids)
71+
72+
except Exception as e:
73+
logging.error('Error during deletion: {}'.format(e))
74+
raise ValueError(str(e))
75+
76+
def rebuild_idx(self, model, mm_type=None):
77+
collection_name_model = get_mm_index_name(model, mm_type)
78+
79+
# 检查集合是否存在,如果存在则删除
80+
collections = self._client.list_collections()
81+
if any(col.name == collection_name_model for col in collections):
82+
self._client.delete_collection(collection_name_model)
83+
else:
84+
return 'model collection not found, please check!'
85+
86+
try:
87+
self._client.create_collection(collection_name_model)
88+
except Exception as e:
89+
logging.info(f'rebuild_collection: {e}')
90+
raise ValueError(str(e))
91+
92+
def rebuild(self, ids=None):
93+
pass
94+
95+
def flush(self):
96+
pass
97+
98+
def close(self):
99+
pass

modelcache_mm/manager/vector_data/manager.py

+9
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def get(name, **kwargs):
108108
dimension=dimension,
109109
top_k=top_k
110110
)
111+
elif name == "chromadb":
112+
from modelcache_mm.manager.vector_data.chroma import Chromadb
113+
114+
chromadb_config = kwargs.get("chromadb_config", None)
115+
persist_directory = chromadb_config.get('chromadb', 'persist_directory')
116+
vector_base = Chromadb(
117+
persist_directory=persist_directory,
118+
top_k=top_k,
119+
)
111120
else:
112121
raise NotFoundError("vector store", name)
113122
return vector_base

modelcache_mm/utils/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,7 @@ def import_pillow():
7373

7474
def import_redis():
7575
_check_library("redis")
76+
77+
78+
def import_chromadb():
79+
_check_library("chromadb", package="chromadb")

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ faiss-cpu==1.7.4
1313
redis==5.0.1
1414
modelscope==1.14.0
1515
fastapi==0.115.5
16-
uvicorn==0.32.0
16+
uvicorn==0.32.0
17+
chromadb==0.5.23

0 commit comments

Comments
 (0)