@@ -252,6 +252,24 @@ def tree(self):
252
252
return render_tree (self )
253
253
254
254
255
+ @dataclasses .dataclass
256
+ class KNNExpression :
257
+ k : int
258
+ vector_field : ModelField
259
+ reference_vector : bytes
260
+
261
+ def __str__ (self ):
262
+ return f"KNN $K @{ self .vector_field .name } $knn_ref_vector"
263
+
264
+ @property
265
+ def query_params (self ) -> Dict [str , Union [str , bytes ]]:
266
+ return {"K" : str (self .k ), "knn_ref_vector" : self .reference_vector }
267
+
268
+ @property
269
+ def score_field (self ) -> str :
270
+ return f"__{ self .vector_field .name } _score"
271
+
272
+
255
273
ExpressionOrNegated = Union [Expression , NegatedExpression ]
256
274
257
275
@@ -349,8 +367,9 @@ def __init__(
349
367
self ,
350
368
expressions : Sequence [ExpressionOrNegated ],
351
369
model : Type ["RedisModel" ],
370
+ knn : Optional [KNNExpression ] = None ,
352
371
offset : int = 0 ,
353
- limit : int = DEFAULT_PAGE_SIZE ,
372
+ limit : Optional [ int ] = None ,
354
373
page_size : int = DEFAULT_PAGE_SIZE ,
355
374
sort_fields : Optional [List [str ]] = None ,
356
375
nocontent : bool = False ,
@@ -364,13 +383,16 @@ def __init__(
364
383
365
384
self .expressions = expressions
366
385
self .model = model
386
+ self .knn = knn
367
387
self .offset = offset
368
- self .limit = limit
388
+ self .limit = limit or ( self . knn . k if self . knn else DEFAULT_PAGE_SIZE )
369
389
self .page_size = page_size
370
390
self .nocontent = nocontent
371
391
372
392
if sort_fields :
373
393
self .sort_fields = self .validate_sort_fields (sort_fields )
394
+ elif self .knn :
395
+ self .sort_fields = [self .knn .score_field ]
374
396
else :
375
397
self .sort_fields = []
376
398
@@ -425,11 +447,26 @@ def query(self):
425
447
if self ._query :
426
448
return self ._query
427
449
self ._query = self .resolve_redisearch_query (self .expression )
450
+ if self .knn :
451
+ self ._query = (
452
+ self ._query
453
+ if self ._query .startswith ("(" ) or self ._query == "*"
454
+ else f"({ self ._query } )"
455
+ ) + f"=>[{ self .knn } ]"
428
456
return self ._query
429
457
458
+ @property
459
+ def query_params (self ):
460
+ params : List [Union [str , bytes ]] = []
461
+ if self .knn :
462
+ params += [attr for kv in self .knn .query_params .items () for attr in kv ]
463
+ return params
464
+
430
465
def validate_sort_fields (self , sort_fields : List [str ]):
431
466
for sort_field in sort_fields :
432
467
field_name = sort_field .lstrip ("-" )
468
+ if self .knn and field_name == self .knn .score_field :
469
+ continue
433
470
if field_name not in self .model .__fields__ :
434
471
raise QueryNotSupportedError (
435
472
f"You tried sort by { field_name } , but that field "
@@ -728,10 +765,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
728
765
return result
729
766
730
767
async def execute (self , exhaust_results = True , return_raw_result = False ):
731
- args = ["ft.search" , self .model .Meta .index_name , self .query , * self .pagination ]
768
+ args : List [Union [str , bytes ]] = [
769
+ "FT.SEARCH" ,
770
+ self .model .Meta .index_name ,
771
+ self .query ,
772
+ * self .pagination ,
773
+ ]
732
774
if self .sort_fields :
733
775
args += self .resolve_redisearch_sort_fields ()
734
776
777
+ if self .query_params :
778
+ args += ["PARAMS" , str (len (self .query_params ))] + self .query_params
779
+
780
+ if self .knn :
781
+ # Ensure DIALECT is at least 2
782
+ if "DIALECT" not in args :
783
+ args += ["DIALECT" , "2" ]
784
+ else :
785
+ i_dialect = args .index ("DIALECT" ) + 1
786
+ if int (args [i_dialect ]) < 2 :
787
+ args [i_dialect ] = "2"
788
+
735
789
if self .nocontent :
736
790
args .append ("NOCONTENT" )
737
791
@@ -917,11 +971,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
917
971
sortable = kwargs .pop ("sortable" , Undefined )
918
972
index = kwargs .pop ("index" , Undefined )
919
973
full_text_search = kwargs .pop ("full_text_search" , Undefined )
974
+ vector_options = kwargs .pop ("vector_options" , None )
920
975
super ().__init__ (default = default , ** kwargs )
921
976
self .primary_key = primary_key
922
977
self .sortable = sortable
923
978
self .index = index
924
979
self .full_text_search = full_text_search
980
+ self .vector_options = vector_options
925
981
926
982
927
983
class RelationshipInfo (Representation ):
@@ -935,6 +991,94 @@ def __init__(
935
991
self .link_model = link_model
936
992
937
993
994
+ @dataclasses .dataclass
995
+ class VectorFieldOptions :
996
+ class ALGORITHM (Enum ):
997
+ FLAT = "FLAT"
998
+ HNSW = "HNSW"
999
+
1000
+ class TYPE (Enum ):
1001
+ FLOAT32 = "FLOAT32"
1002
+ FLOAT64 = "FLOAT64"
1003
+
1004
+ class DISTANCE_METRIC (Enum ):
1005
+ L2 = "L2"
1006
+ IP = "IP"
1007
+ COSINE = "COSINE"
1008
+
1009
+ algorithm : ALGORITHM
1010
+ type : TYPE
1011
+ dimension : int
1012
+ distance_metric : DISTANCE_METRIC
1013
+
1014
+ # Common optional parameters
1015
+ initial_cap : Optional [int ] = None
1016
+
1017
+ # Optional parameters for FLAT
1018
+ block_size : Optional [int ] = None
1019
+
1020
+ # Optional parameters for HNSW
1021
+ m : Optional [int ] = None
1022
+ ef_construction : Optional [int ] = None
1023
+ ef_runtime : Optional [int ] = None
1024
+ epsilon : Optional [float ] = None
1025
+
1026
+ @staticmethod
1027
+ def flat (
1028
+ type : TYPE ,
1029
+ dimension : int ,
1030
+ distance_metric : DISTANCE_METRIC ,
1031
+ initial_cap : Optional [int ] = None ,
1032
+ block_size : Optional [int ] = None ,
1033
+ ):
1034
+ return VectorFieldOptions (
1035
+ algorithm = VectorFieldOptions .ALGORITHM .FLAT ,
1036
+ type = type ,
1037
+ dimension = dimension ,
1038
+ distance_metric = distance_metric ,
1039
+ initial_cap = initial_cap ,
1040
+ block_size = block_size ,
1041
+ )
1042
+
1043
+ @staticmethod
1044
+ def hnsw (
1045
+ type : TYPE ,
1046
+ dimension : int ,
1047
+ distance_metric : DISTANCE_METRIC ,
1048
+ initial_cap : Optional [int ] = None ,
1049
+ m : Optional [int ] = None ,
1050
+ ef_construction : Optional [int ] = None ,
1051
+ ef_runtime : Optional [int ] = None ,
1052
+ epsilon : Optional [float ] = None ,
1053
+ ):
1054
+ return VectorFieldOptions (
1055
+ algorithm = VectorFieldOptions .ALGORITHM .HNSW ,
1056
+ type = type ,
1057
+ dimension = dimension ,
1058
+ distance_metric = distance_metric ,
1059
+ initial_cap = initial_cap ,
1060
+ m = m ,
1061
+ ef_construction = ef_construction ,
1062
+ ef_runtime = ef_runtime ,
1063
+ epsilon = epsilon ,
1064
+ )
1065
+
1066
+ @property
1067
+ def schema (self ):
1068
+ attr = []
1069
+ for k , v in vars (self ).items ():
1070
+ if k == "algorithm" or v is None :
1071
+ continue
1072
+ attr .extend (
1073
+ [
1074
+ k .upper () if k != "dimension" else "DIM" ,
1075
+ str (v ) if not isinstance (v , Enum ) else v .name ,
1076
+ ]
1077
+ )
1078
+
1079
+ return " " .join ([f"VECTOR { self .algorithm .name } { len (attr )} " ] + attr )
1080
+
1081
+
938
1082
def Field (
939
1083
default : Any = Undefined ,
940
1084
* ,
@@ -964,6 +1108,7 @@ def Field(
964
1108
sortable : Union [bool , UndefinedType ] = Undefined ,
965
1109
index : Union [bool , UndefinedType ] = Undefined ,
966
1110
full_text_search : Union [bool , UndefinedType ] = Undefined ,
1111
+ vector_options : Optional [VectorFieldOptions ] = None ,
967
1112
schema_extra : Optional [Dict [str , Any ]] = None ,
968
1113
) -> Any :
969
1114
current_schema_extra = schema_extra or {}
@@ -991,6 +1136,7 @@ def Field(
991
1136
sortable = sortable ,
992
1137
index = index ,
993
1138
full_text_search = full_text_search ,
1139
+ vector_options = vector_options ,
994
1140
** current_schema_extra ,
995
1141
)
996
1142
field_info ._validate ()
@@ -1083,6 +1229,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
1083
1229
new_class ._meta .primary_key = PrimaryKey (
1084
1230
name = field_name , field = field
1085
1231
)
1232
+ if field .field_info .vector_options :
1233
+ score_attr = f"_{ field_name } _score"
1234
+ setattr (new_class , score_attr , None )
1235
+ new_class .__annotations__ [score_attr ] = Union [float , None ]
1086
1236
1087
1237
if not getattr (new_class ._meta , "global_key_prefix" , None ):
1088
1238
new_class ._meta .global_key_prefix = getattr (
@@ -1216,8 +1366,12 @@ def db(cls):
1216
1366
return cls ._meta .database
1217
1367
1218
1368
@classmethod
1219
- def find (cls , * expressions : Union [Any , Expression ]) -> FindQuery :
1220
- return FindQuery (expressions = expressions , model = cls )
1369
+ def find (
1370
+ cls ,
1371
+ * expressions : Union [Any , Expression ],
1372
+ knn : Optional [KNNExpression ] = None ,
1373
+ ) -> FindQuery :
1374
+ return FindQuery (expressions = expressions , knn = knn , model = cls )
1221
1375
1222
1376
@classmethod
1223
1377
def from_redis (cls , res : Any ):
@@ -1237,7 +1391,7 @@ def to_string(s):
1237
1391
for i in range (1 , len (res ), step ):
1238
1392
if res [i + offset ] is None :
1239
1393
continue
1240
- fields = dict (
1394
+ fields : Dict [ str , str ] = dict (
1241
1395
zip (
1242
1396
map (to_string , res [i + offset ][::2 ]),
1243
1397
map (to_string , res [i + offset ][1 ::2 ]),
@@ -1247,6 +1401,9 @@ def to_string(s):
1247
1401
if fields .get ("$" ):
1248
1402
json_fields = json .loads (fields .pop ("$" ))
1249
1403
doc = cls (** json_fields )
1404
+ for k , v in fields .items ():
1405
+ if k .startswith ("__" ) and k .endswith ("_score" ):
1406
+ setattr (doc , k [1 :], float (v ))
1250
1407
else :
1251
1408
doc = cls (** fields )
1252
1409
@@ -1474,7 +1631,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
1474
1631
embedded_cls = embedded_cls [0 ]
1475
1632
schema = cls .schema_for_type (name , embedded_cls , field_info )
1476
1633
elif any (issubclass (typ , t ) for t in NUMERIC_TYPES ):
1477
- schema = f"{ name } NUMERIC"
1634
+ vector_options : Optional [VectorFieldOptions ] = getattr (
1635
+ field_info , "vector_options" , None
1636
+ )
1637
+ if vector_options :
1638
+ schema = f"{ name } { vector_options .schema } "
1639
+ else :
1640
+ schema = f"{ name } NUMERIC"
1478
1641
elif issubclass (typ , str ):
1479
1642
if getattr (field_info , "full_text_search" , False ) is True :
1480
1643
schema = (
@@ -1623,10 +1786,22 @@ def schema_for_type(
1623
1786
# Not a class, probably a type annotation
1624
1787
field_is_model = False
1625
1788
1789
+ vector_options : Optional [VectorFieldOptions ] = getattr (
1790
+ field_info , "vector_options" , None
1791
+ )
1792
+ try :
1793
+ is_vector = vector_options and any (
1794
+ issubclass (get_args (typ )[0 ], t ) for t in NUMERIC_TYPES
1795
+ )
1796
+ except IndexError :
1797
+ raise RedisModelError (
1798
+ f"Vector field '{ name } ' must be annotated as a container type"
1799
+ )
1800
+
1626
1801
# When we encounter a list or model field, we need to descend
1627
1802
# into the values of the list or the fields of the model to
1628
1803
# find any values marked as indexed.
1629
- if is_container_type :
1804
+ if is_container_type and not is_vector :
1630
1805
field_type = get_origin (typ )
1631
1806
embedded_cls = get_args (typ )
1632
1807
if not embedded_cls :
@@ -1689,7 +1864,9 @@ def schema_for_type(
1689
1864
)
1690
1865
1691
1866
# TODO: GEO field
1692
- if parent_is_container_type or parent_is_model_in_container :
1867
+ if is_vector and vector_options :
1868
+ schema = f"{ path } AS { index_field_name } { vector_options .schema } "
1869
+ elif parent_is_container_type or parent_is_model_in_container :
1693
1870
if typ is not str :
1694
1871
raise RedisModelError (
1695
1872
"In this Preview release, list and tuple fields can only "
0 commit comments