Skip to content

Commit 89b6c84

Browse files
Pwutsch and chayim authored
Add support for KNN vector similarity search (redis#513)
Co-authored-by: Chayim <[email protected]>
1 parent 70f6401 commit 89b6c84

File tree

3 files changed

+190
-9
lines changed

3 files changed

+190
-9
lines changed

aredis_om/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
FindQuery,
99
HashModel,
1010
JsonModel,
11+
VectorFieldOptions,
12+
KNNExpression,
1113
NotFoundError,
1214
QueryNotSupportedError,
1315
QuerySyntaxError,

aredis_om/model/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
Field,
55
HashModel,
66
JsonModel,
7+
VectorFieldOptions,
8+
KNNExpression,
79
NotFoundError,
810
RedisModel,
911
)

aredis_om/model/model.py

+186-9
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,24 @@ def tree(self):
252252
return render_tree(self)
253253

254254

255+
@dataclasses.dataclass
class KNNExpression:
    """A K-nearest-neighbours clause that can be attached to a find query.

    Attributes:
        k: Number of neighbours to request.
        vector_field: Model field holding the vectors; only its ``name`` is
            read by this class.
        reference_vector: Raw bytes of the vector to compare against.
            NOTE(review): presumably this must be packed to match the field's
            declared vector type and dimension -- confirm with callers.
    """

    k: int
    vector_field: ModelField
    reference_vector: bytes

    def __str__(self):
        """Render the KNN clause with ``$K`` / ``$knn_ref_vector`` placeholders.

        The placeholder values themselves are supplied via ``query_params``.
        """
        field_name = self.vector_field.name
        return "KNN $K @{} $knn_ref_vector".format(field_name)

    @property
    def query_params(self) -> Dict[str, Union[str, bytes]]:
        """Return the values for the placeholders used by ``__str__``."""
        params: Dict[str, Union[str, bytes]] = {}
        params["K"] = str(self.k)
        params["knn_ref_vector"] = self.reference_vector
        return params

    @property
    def score_field(self) -> str:
        """Return the score field name derived from the vector field name.

        The name has the form ``__<vector field name>_score``.
        """
        field_name = self.vector_field.name
        return "__{}_score".format(field_name)
271+
272+
255273
ExpressionOrNegated = Union[Expression, NegatedExpression]
256274

257275

@@ -349,8 +367,9 @@ def __init__(
349367
self,
350368
expressions: Sequence[ExpressionOrNegated],
351369
model: Type["RedisModel"],
370+
knn: Optional[KNNExpression] = None,
352371
offset: int = 0,
353-
limit: int = DEFAULT_PAGE_SIZE,
372+
limit: Optional[int] = None,
354373
page_size: int = DEFAULT_PAGE_SIZE,
355374
sort_fields: Optional[List[str]] = None,
356375
nocontent: bool = False,
@@ -364,13 +383,16 @@ def __init__(
364383

365384
self.expressions = expressions
366385
self.model = model
386+
self.knn = knn
367387
self.offset = offset
368-
self.limit = limit
388+
self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE)
369389
self.page_size = page_size
370390
self.nocontent = nocontent
371391

372392
if sort_fields:
373393
self.sort_fields = self.validate_sort_fields(sort_fields)
394+
elif self.knn:
395+
self.sort_fields = [self.knn.score_field]
374396
else:
375397
self.sort_fields = []
376398

@@ -425,11 +447,26 @@ def query(self):
425447
if self._query:
426448
return self._query
427449
self._query = self.resolve_redisearch_query(self.expression)
450+
if self.knn:
451+
self._query = (
452+
self._query
453+
if self._query.startswith("(") or self._query == "*"
454+
else f"({self._query})"
455+
) + f"=>[{self.knn}]"
428456
return self._query
429457

458+
@property
def query_params(self):
    """Return the KNN parameters flattened to ``[name, value, name, value, ...]``.

    A new list is built on every access; it is empty when the query has no
    KNN expression attached.
    """
    if not self.knn:
        return []

    flattened: List[Union[str, bytes]] = []
    for param_name, param_value in self.knn.query_params.items():
        flattened.append(param_name)
        flattened.append(param_value)
    return flattened
464+
430465
def validate_sort_fields(self, sort_fields: List[str]):
431466
for sort_field in sort_fields:
432467
field_name = sort_field.lstrip("-")
468+
if self.knn and field_name == self.knn.score_field:
469+
continue
433470
if field_name not in self.model.__fields__:
434471
raise QueryNotSupportedError(
435472
f"You tried sort by {field_name}, but that field "
@@ -728,10 +765,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
728765
return result
729766

730767
async def execute(self, exhaust_results=True, return_raw_result=False):
731-
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
768+
args: List[Union[str, bytes]] = [
769+
"FT.SEARCH",
770+
self.model.Meta.index_name,
771+
self.query,
772+
*self.pagination,
773+
]
732774
if self.sort_fields:
733775
args += self.resolve_redisearch_sort_fields()
734776

777+
if self.query_params:
778+
args += ["PARAMS", str(len(self.query_params))] + self.query_params
779+
780+
if self.knn:
781+
# Ensure DIALECT is at least 2
782+
if "DIALECT" not in args:
783+
args += ["DIALECT", "2"]
784+
else:
785+
i_dialect = args.index("DIALECT") + 1
786+
if int(args[i_dialect]) < 2:
787+
args[i_dialect] = "2"
788+
735789
if self.nocontent:
736790
args.append("NOCONTENT")
737791

@@ -917,11 +971,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
917971
sortable = kwargs.pop("sortable", Undefined)
918972
index = kwargs.pop("index", Undefined)
919973
full_text_search = kwargs.pop("full_text_search", Undefined)
974+
vector_options = kwargs.pop("vector_options", None)
920975
super().__init__(default=default, **kwargs)
921976
self.primary_key = primary_key
922977
self.sortable = sortable
923978
self.index = index
924979
self.full_text_search = full_text_search
980+
self.vector_options = vector_options
925981

926982

927983
class RelationshipInfo(Representation):
@@ -935,6 +991,94 @@ def __init__(
935991
self.link_model = link_model
936992

937993

994+
@dataclasses.dataclass
class VectorFieldOptions:
    """Options describing how a vector field is indexed.

    An instance renders itself, via ``schema``, as a
    ``VECTOR <algorithm> <count> <NAME> <value> ...`` clause. Use the
    ``flat`` / ``hnsw`` constructors to get an instance with the matching
    algorithm preselected and only that algorithm's options exposed.
    """

    class ALGORITHM(Enum):
        """Indexing algorithm; its name follows ``VECTOR`` in the schema."""

        FLAT = "FLAT"
        HNSW = "HNSW"

    class TYPE(Enum):
        """Element type of the stored vectors."""

        FLOAT32 = "FLOAT32"
        FLOAT64 = "FLOAT64"

    class DISTANCE_METRIC(Enum):
        """Metric used to compare vectors."""

        L2 = "L2"
        IP = "IP"
        COSINE = "COSINE"

    algorithm: ALGORITHM
    type: TYPE
    dimension: int
    distance_metric: DISTANCE_METRIC

    # Common optional parameters
    initial_cap: Optional[int] = None

    # Optional parameters for FLAT
    block_size: Optional[int] = None

    # Optional parameters for HNSW
    m: Optional[int] = None
    ef_construction: Optional[int] = None
    ef_runtime: Optional[int] = None
    epsilon: Optional[float] = None

    @classmethod
    def flat(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        block_size: Optional[int] = None,
    ) -> "VectorFieldOptions":
        """Build options for the FLAT algorithm.

        A classmethod (rather than a staticmethod) so that subclasses get an
        instance of their own type back.
        """
        return cls(
            algorithm=cls.ALGORITHM.FLAT,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            block_size=block_size,
        )

    @classmethod
    def hnsw(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
        ef_runtime: Optional[int] = None,
        epsilon: Optional[float] = None,
    ) -> "VectorFieldOptions":
        """Build options for the HNSW algorithm.

        A classmethod (rather than a staticmethod) so that subclasses get an
        instance of their own type back.
        """
        return cls(
            algorithm=cls.ALGORITHM.HNSW,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            m=m,
            ef_construction=ef_construction,
            ef_runtime=ef_runtime,
            epsilon=epsilon,
        )

    @property
    def schema(self) -> str:
        """Return the ``VECTOR ...`` schema clause for these options.

        Every declared option that is not ``None`` is emitted as an upper-cased
        name followed by its value (``dimension`` is emitted as ``DIM``). The
        number placed after the algorithm name is the count of those emitted
        tokens.
        """
        attr = []
        # Walk the declared dataclass fields rather than ``vars(self)`` so an
        # attribute set ad hoc on the instance can never leak into the schema.
        for option in dataclasses.fields(self):
            value = getattr(self, option.name)
            # The algorithm is rendered in the clause prefix, not as an option.
            if option.name == "algorithm" or value is None:
                continue
            attr.append("DIM" if option.name == "dimension" else option.name.upper())
            attr.append(value.name if isinstance(value, Enum) else str(value))

        return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}", *attr])
1080+
1081+
9381082
def Field(
9391083
default: Any = Undefined,
9401084
*,
@@ -964,6 +1108,7 @@ def Field(
9641108
sortable: Union[bool, UndefinedType] = Undefined,
9651109
index: Union[bool, UndefinedType] = Undefined,
9661110
full_text_search: Union[bool, UndefinedType] = Undefined,
1111+
vector_options: Optional[VectorFieldOptions] = None,
9671112
schema_extra: Optional[Dict[str, Any]] = None,
9681113
) -> Any:
9691114
current_schema_extra = schema_extra or {}
@@ -991,6 +1136,7 @@ def Field(
9911136
sortable=sortable,
9921137
index=index,
9931138
full_text_search=full_text_search,
1139+
vector_options=vector_options,
9941140
**current_schema_extra,
9951141
)
9961142
field_info._validate()
@@ -1083,6 +1229,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
10831229
new_class._meta.primary_key = PrimaryKey(
10841230
name=field_name, field=field
10851231
)
1232+
if field.field_info.vector_options:
1233+
score_attr = f"_{field_name}_score"
1234+
setattr(new_class, score_attr, None)
1235+
new_class.__annotations__[score_attr] = Union[float, None]
10861236

10871237
if not getattr(new_class._meta, "global_key_prefix", None):
10881238
new_class._meta.global_key_prefix = getattr(
@@ -1216,8 +1366,12 @@ def db(cls):
12161366
return cls._meta.database
12171367

12181368
@classmethod
1219-
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
1220-
return FindQuery(expressions=expressions, model=cls)
1369+
def find(
1370+
cls,
1371+
*expressions: Union[Any, Expression],
1372+
knn: Optional[KNNExpression] = None,
1373+
) -> FindQuery:
1374+
return FindQuery(expressions=expressions, knn=knn, model=cls)
12211375

12221376
@classmethod
12231377
def from_redis(cls, res: Any):
@@ -1237,7 +1391,7 @@ def to_string(s):
12371391
for i in range(1, len(res), step):
12381392
if res[i + offset] is None:
12391393
continue
1240-
fields = dict(
1394+
fields: Dict[str, str] = dict(
12411395
zip(
12421396
map(to_string, res[i + offset][::2]),
12431397
map(to_string, res[i + offset][1::2]),
@@ -1247,6 +1401,9 @@ def to_string(s):
12471401
if fields.get("$"):
12481402
json_fields = json.loads(fields.pop("$"))
12491403
doc = cls(**json_fields)
1404+
for k, v in fields.items():
1405+
if k.startswith("__") and k.endswith("_score"):
1406+
setattr(doc, k[1:], float(v))
12501407
else:
12511408
doc = cls(**fields)
12521409

@@ -1474,7 +1631,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
14741631
embedded_cls = embedded_cls[0]
14751632
schema = cls.schema_for_type(name, embedded_cls, field_info)
14761633
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
1477-
schema = f"{name} NUMERIC"
1634+
vector_options: Optional[VectorFieldOptions] = getattr(
1635+
field_info, "vector_options", None
1636+
)
1637+
if vector_options:
1638+
schema = f"{name} {vector_options.schema}"
1639+
else:
1640+
schema = f"{name} NUMERIC"
14781641
elif issubclass(typ, str):
14791642
if getattr(field_info, "full_text_search", False) is True:
14801643
schema = (
@@ -1623,10 +1786,22 @@ def schema_for_type(
16231786
# Not a class, probably a type annotation
16241787
field_is_model = False
16251788

1789+
vector_options: Optional[VectorFieldOptions] = getattr(
1790+
field_info, "vector_options", None
1791+
)
1792+
try:
1793+
is_vector = vector_options and any(
1794+
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
1795+
)
1796+
except IndexError:
1797+
raise RedisModelError(
1798+
f"Vector field '{name}' must be annotated as a container type"
1799+
)
1800+
16261801
# When we encounter a list or model field, we need to descend
16271802
# into the values of the list or the fields of the model to
16281803
# find any values marked as indexed.
1629-
if is_container_type:
1804+
if is_container_type and not is_vector:
16301805
field_type = get_origin(typ)
16311806
embedded_cls = get_args(typ)
16321807
if not embedded_cls:
@@ -1689,7 +1864,9 @@ def schema_for_type(
16891864
)
16901865

16911866
# TODO: GEO field
1692-
if parent_is_container_type or parent_is_model_in_container:
1867+
if is_vector and vector_options:
1868+
schema = f"{path} AS {index_field_name} {vector_options.schema}"
1869+
elif parent_is_container_type or parent_is_model_in_container:
16931870
if typ is not str:
16941871
raise RedisModelError(
16951872
"In this Preview release, list and tuple fields can only "

0 commit comments

Comments
 (0)