Skip to content

Commit f77c21a

Browse files
authored
Add support for count (#397)
* add count * update poetry
1 parent 1221efd commit f77c21a

File tree

4 files changed

+97
-56
lines changed

4 files changed

+97
-56
lines changed

Diff for: aredis_om/model/model.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def __init__(
345345
limit: int = DEFAULT_PAGE_SIZE,
346346
page_size: int = DEFAULT_PAGE_SIZE,
347347
sort_fields: Optional[List[str]] = None,
348+
nocontent: bool = False,
348349
):
349350
if not has_redisearch(model.db()):
350351
raise RedisModelError(
@@ -358,6 +359,7 @@ def __init__(
358359
self.offset = offset
359360
self.limit = limit
360361
self.page_size = page_size
362+
self.nocontent = nocontent
361363

362364
if sort_fields:
363365
self.sort_fields = self.validate_sort_fields(sort_fields)
@@ -377,6 +379,7 @@ def dict(self) -> Dict[str, Any]:
377379
limit=self.limit,
378380
expressions=copy(self.expressions),
379381
sort_fields=copy(self.sort_fields),
382+
nocontent=self.nocontent,
380383
)
381384

382385
def copy(self, **kwargs):
@@ -716,18 +719,23 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
716719

717720
return result
718721

719-
async def execute(self, exhaust_results=True):
722+
async def execute(self, exhaust_results=True, return_raw_result=False):
720723
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
721724
if self.sort_fields:
722725
args += self.resolve_redisearch_sort_fields()
723726

727+
if self.nocontent:
728+
args.append("NOCONTENT")
729+
724730
# Reset the cache if we're executing from offset 0.
725731
if self.offset == 0:
726732
self._model_cache.clear()
727733

728734
# If the offset is greater than 0, we're paginating through a result set,
729735
# so append the new results to results already in the cache.
730736
raw_result = await self.model.db().execute_command(*args)
737+
if return_raw_result:
738+
return raw_result
731739
count = raw_result[0]
732740
results = self.model.from_redis(raw_result)
733741
self._model_cache += results
@@ -759,6 +767,11 @@ async def first(self):
759767
raise NotFoundError()
760768
return results[0]
761769

770+
async def count(self):
771+
query = self.copy(offset=0, limit=0, nocontent=True)
772+
result = await query.execute(exhaust_results=True, return_raw_result=True)
773+
return result[0]
774+
762775
async def all(self, batch_size=DEFAULT_PAGE_SIZE):
763776
if batch_size != self.page_size:
764777
query = self.copy(page_size=batch_size, limit=batch_size)
@@ -1175,7 +1188,7 @@ def validate_primary_key(cls):
11751188
if primary_keys == 0:
11761189
raise RedisModelError("You must define a primary key for the model")
11771190
elif primary_keys == 2:
1178-
cls.__fields__.pop('pk')
1191+
cls.__fields__.pop("pk")
11791192
elif primary_keys > 2:
11801193
raise RedisModelError("You must define only one primary key for a model")
11811194

Diff for: poetry.lock

+45-47
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: tests/test_hash_model.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,6 @@ class Address(m.BaseHashModel):
706706

707707
@py_test_mark_asyncio
708708
async def test_primary_key_model_error(m):
709-
710709
class Customer(m.BaseHashModel):
711710
id: int = Field(primary_key=True, index=True)
712711
first_name: str = Field(primary_key=True, index=True)
@@ -715,18 +714,19 @@ class Customer(m.BaseHashModel):
715714

716715
await Migrator().run()
717716

718-
with pytest.raises(RedisModelError, match="You must define only one primary key for a model"):
717+
with pytest.raises(
718+
RedisModelError, match="You must define only one primary key for a model"
719+
):
719720
_ = Customer(
720721
id=0,
721722
first_name="Mahmoud",
722723
last_name="Harmouch",
723-
bio="Python developer, wanna work at Redis, Inc."
724+
bio="Python developer, wanna work at Redis, Inc.",
724725
)
725726

726727

727728
@py_test_mark_asyncio
728729
async def test_primary_pk_exists(m):
729-
730730
class Customer1(m.BaseHashModel):
731731
id: int
732732
first_name: str
@@ -745,10 +745,10 @@ class Customer2(m.BaseHashModel):
745745
id=0,
746746
first_name="Mahmoud",
747747
last_name="Harmouch",
748-
bio="Python developer, wanna work at Redis, Inc."
748+
bio="Python developer, wanna work at Redis, Inc.",
749749
)
750750

751-
assert 'pk' in customer.__fields__
751+
assert "pk" in customer.__fields__
752752

753753
customer = Customer2(
754754
id=1,
@@ -757,4 +757,19 @@ class Customer2(m.BaseHashModel):
757757
bio="This is member 2 who can be quite anxious until you get to know them.",
758758
)
759759

760-
assert 'pk' not in customer.__fields__
760+
assert "pk" not in customer.__fields__
761+
762+
763+
@py_test_mark_asyncio
764+
async def test_count(members, m):
765+
# member1, member2, member3 = members
766+
actual_count = await m.Member.find(
767+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
768+
| (m.Member.last_name == "Smith")
769+
).count()
770+
assert actual_count == 2
771+
772+
actual_count = await m.Member.find(
773+
m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
774+
).count()
775+
assert actual_count == 1

0 commit comments

Comments
 (0)