1
1
# -*- coding: utf-8 -*-
2
2
from typing import List
3
- < << << << HEAD
4
-
5
- import numpy as np
6
- from modelcache .manager .vector_data .base import VectorBase , VectorData
7
- from modelcache .utils import import_redis
8
- from redis .commands .search .query import Query
9
- from redis .commands .search .indexDefinition import IndexDefinition , IndexType
10
- from modelcache .utils .log import modelcache_log
11
-
12
- import_redis ()
13
- #
14
- # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
15
- # from redis.commands.search.query import Query
16
- # from redis.commands.search.field import TagField, VectorField
17
- # from redis.client import Redis
18
- == == == =
19
3
import numpy as np
20
4
from redis .commands .search .indexDefinition import IndexDefinition , IndexType
21
5
from redis .commands .search .query import Query
28
12
from modelcache .utils .index_util import get_index_name
29
13
from modelcache .utils .index_util import get_index_prefix
30
14
import_redis ()
31
- > >> >> >> main
32
15
33
16
34
17
class RedisVectorStore (VectorBase ):
@@ -39,103 +22,25 @@ def __init__(
39
22
username : str = "" ,
40
23
password : str = "" ,
41
24
dimension : int = 0 ,
42
- << << << < HEAD
43
- collection_name : str = "gptcache" ,
44
- top_k : int = 1 ,
45
- namespace : str = "" ,
46
- ):
47
- == == == =
48
25
top_k : int = 1 ,
49
26
namespace : str = "" ,
50
27
):
51
28
if dimension <= 0 :
52
29
raise ValueError (
53
30
f"invalid `dim` param: { dimension } in the Redis vector store."
54
31
)
55
- >> >> >> > main
56
32
self ._client = Redis (
57
33
host = host , port = int (port ), username = username , password = password
58
34
)
59
35
self .top_k = top_k
60
36
self .dimension = dimension
61
- < << << << HEAD
62
- self .collection_name = collection_name
63
- self .namespace = namespace
64
- self .doc_prefix = f"{ self .namespace } doc:" # Prefix with the specified namespace
65
- self ._create_collection (collection_name )
66
- == == == =
67
37
self .namespace = namespace
68
38
self .doc_prefix = f"{ self .namespace } doc:"
69
- >> >> >> > main
70
39
71
40
def _check_index_exists (self , index_name : str ) -> bool :
72
41
"""Check if Redis index exists."""
73
42
try :
74
43
self ._client .ft (index_name ).info ()
75
- << << << < HEAD
76
- except : # pylint: disable=W0702
77
- gptcache_log .info ("Index does not exist" )
78
- return False
79
- gptcache_log .info ("Index already exists" )
80
- return True
81
-
82
- def _create_collection (self , collection_name ):
83
- if self ._check_index_exists (collection_name ):
84
- gptcache_log .info (
85
- "The %s already exists, and it will be used directly" , collection_name
86
- )
87
- else :
88
- schema = (
89
- TagField ("tag" ), # Tag Field Name
90
- VectorField (
91
- "vector" , # Vector Field Name
92
- "FLAT" ,
93
- { # Vector Index Type: FLAT or HNSW
94
- "TYPE" : "FLOAT32" , # FLOAT32 or FLOAT64
95
- "DIM" : self .dimension , # Number of Vector Dimensions
96
- "DISTANCE_METRIC" : "COSINE" , # Vector Search Distance Metric
97
- },
98
- ),
99
- )
100
- definition = IndexDefinition (
101
- prefix = [self .doc_prefix ], index_type = IndexType .HASH
102
- )
103
-
104
- # create Index
105
- self ._client .ft (collection_name ).create_index (
106
- fields = schema , definition = definition
107
- )
108
-
109
- def mul_add (self , datas : List [VectorData ]):
110
- pipe = self ._client .pipeline ()
111
-
112
- for data in datas :
113
- key : int = data .id
114
- obj = {
115
- "vector" : data .data .astype (np .float32 ).tobytes (),
116
- }
117
- pipe .hset (f"{ self .doc_prefix } { key } " , mapping = obj )
118
-
119
- pipe .execute ()
120
-
121
- def search (self , data : np .ndarray , top_k : int = - 1 ):
122
- query = (
123
- Query (
124
- f"*=>[KNN { top_k if top_k > 0 else self .top_k } @vector $vec as score]"
125
- )
126
- .sort_by ("score" )
127
- .return_fields ("id" , "score" )
128
- .paging (0 , top_k if top_k > 0 else self .top_k )
129
- .dialect (2 )
130
- )
131
- query_params = {"vec" : data .astype (np .float32 ).tobytes ()}
132
- results = (
133
- self ._client .ft (self .collection_name )
134
- .search (query , query_params = query_params )
135
- .docs
136
- )
137
- return [(float (result .score ), int (result .id [len (self .doc_prefix ):])) for result in results ]
138
- == == == =
139
44
except :
140
45
modelcache_log .info ("Index does not exist" )
141
46
return False
@@ -201,13 +106,10 @@ def search(self, data: np.ndarray, top_k: int = -1, model=None):
201
106
.docs
202
107
)
203
108
return [(float (result .distance ), int (getattr (result , id_field_name ))) for result in results ]
204
- > >> >> >> main
205
109
206
110
def rebuild (self , ids = None ) -> bool :
207
111
pass
208
112
209
- < << << << HEAD
210
- == == == =
211
113
def rebuild_col (self , model ):
212
114
index_name_model = get_index_name (model )
213
115
if self ._check_index_exists (index_name_model ):
@@ -222,14 +124,10 @@ def rebuild_col(self, model):
222
124
raise ValueError (str (e ))
223
125
# return 'rebuild success'
224
126
225
- >> >> >> > main
226
127
def delete (self , ids ) -> None :
227
128
pipe = self ._client .pipeline ()
228
129
for data_id in ids :
229
130
pipe .delete (f"{ self .doc_prefix } { data_id } " )
230
- < << << << HEAD
231
- pipe .execute ()
232
- == == == =
233
131
pipe .execute ()
234
132
235
133
def create (self , model = None ):
@@ -239,4 +137,3 @@ def create(self, model=None):
239
137
240
138
def get_index_by_name (self , index_name ):
241
139
pass
242
- > >> >> >> main
0 commit comments