Skip to content

Commit 2df1beb

Browse files
author
savynorem
committed
added return_fields function, attempting to optionally limit fields returned by find
1 parent 44cbeaf commit 2df1beb

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

aredis_om/model/model.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def __init__(
421421
limit: Optional[int] = None,
422422
page_size: int = DEFAULT_PAGE_SIZE,
423423
sort_fields: Optional[List[str]] = None,
424+
return_fields: Optional[List[str]] = None,
424425
nocontent: bool = False,
425426
):
426427
if not has_redisearch(model.db()):
@@ -445,6 +446,11 @@ def __init__(
445446
else:
446447
self.sort_fields = []
447448

449+
if return_fields:
450+
self.return_fields = self.validate_return_fields(return_fields)
451+
else:
452+
self.return_fields = []
453+
448454
self._expression = None
449455
self._query: Optional[str] = None
450456
self._pagination: List[str] = []
@@ -502,8 +508,19 @@ def query(self):
502508
if self._query.startswith("(") or self._query == "*"
503509
else f"({self._query})"
504510
) + f"=>[{self.knn}]"
511+
if self.return_fields:
512+
self._query += f" RETURN {','.join(self.return_fields)}"
505513
return self._query
506514

515+
def validate_return_fields(self, return_fields: List[str]):
516+
for field in return_fields:
517+
if field not in self.model.__fields__: # type: ignore
518+
raise QueryNotSupportedError(
519+
f"You tried to return the field {field}, but that field "
520+
f"does not exist on the model {self.model}"
521+
)
522+
return return_fields
523+
507524
@property
508525
def query_params(self):
509526
params: List[Union[str, bytes]] = []
@@ -956,6 +973,11 @@ def sort_by(self, *fields: str):
956973
if not fields:
957974
return self
958975
return self.copy(sort_fields=list(fields))
976+
977+
def return_fields(self, *fields: str):
978+
if not fields:
979+
return self
980+
return self.copy(return_fields=list(fields))
959981

960982
async def update(self, use_transaction=True, **field_values):
961983
"""
@@ -1531,7 +1553,9 @@ def find(
15311553
*expressions: Union[Any, Expression],
15321554
knn: Optional[KNNExpression] = None,
15331555
) -> FindQuery:
1534-
return FindQuery(expressions=expressions, knn=knn, model=cls)
1556+
return FindQuery(
1557+
expressions=expressions, knn=knn, model=cls
1558+
)
15351559

15361560
@classmethod
15371561
def from_redis(cls, res: Any):

tests/test_json_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,9 +935,21 @@ class TypeWithUuid(JsonModel):
935935

936936
await item.save()
937937

938+
@py_test_mark_asyncio
939+
async def test_return_specified_fields(members, m):
940+
member1, member2, member3 = members
941+
actual = await m.Member.find(
942+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
943+
| (m.Member.last_name == "Smith")
944+
).all()
945+
assert actual == [
946+
{"first_name": "Andrew", "last_name": "Brookins"},
947+
{"first_name": "Andrew", "last_name": "Smith"},
948+
]
949+
938950

939951
@py_test_mark_asyncio
940-
async def test_xfix_queries(m):
952+
async def test_xfix_queries(m):4
941953
await m.Member(
942954
first_name="Steve",
943955
last_name="Lorello",

0 commit comments

Comments
 (0)