Skip to content

Commit ee80f7c

Browse files
committed
fix review comments
1 parent ba7d552 commit ee80f7c

File tree

4 files changed

+31
-69
lines changed

4 files changed

+31
-69
lines changed

aredis_om/model/model.py

+28-41
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pydantic.utils import Representation
3535
from typing_extensions import Protocol, get_args, get_origin
3636
from ulid import ULID
37+
from more_itertools import ichunked
3738

3839
from ..checks import has_redis_json, has_redisearch
3940
from ..connections import get_redis_connection
@@ -1114,14 +1115,16 @@ def key(self):
11141115
pk = getattr(self, self._meta.primary_key.field.name)
11151116
return self.make_primary_key(pk)
11161117

1118+
@classmethod
1119+
async def _delete(cls, db, *pks):
1120+
return await db.delete(*pks)
1121+
11171122
@classmethod
11181123
async def delete(cls, pk: Any, pipeline: Optional[Pipeline] = None) -> int:
11191124
"""Delete data at this key."""
1120-
if pipeline is None:
1121-
db = cls.db()
1122-
else:
1123-
db = pipeline
1124-
return await db.delete(cls.make_primary_key(pk))
1125+
db = cls._get_db(pipeline)
1126+
1127+
return await cls._delete(db, cls.make_primary_key(pk))
11251128

11261129
@classmethod
11271130
async def get(cls, pk: Any) -> "RedisModel":
@@ -1135,10 +1138,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
11351138
raise NotImplementedError
11361139

11371140
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
1138-
if pipeline is None:
1139-
db = self.db()
1140-
else:
1141-
db = pipeline
1141+
db = self._get_db(pipeline)
11421142

11431143
# TODO: Wrap any Redis response errors in a custom exception?
11441144
await db.expire(self.make_primary_key(self.pk), num_seconds)
@@ -1248,16 +1248,7 @@ async def add(
12481248
pipeline: Optional[Pipeline] = None,
12491249
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12501250
) -> Sequence["RedisModel"]:
1251-
if pipeline is None:
1252-
# By default, send commands in a pipeline. Saving each model will
1253-
# be atomic, but Redis may process other commands in between
1254-
# these saves.
1255-
db = cls.db().pipeline(transaction=False)
1256-
else:
1257-
# If the user gave us a pipeline, add our commands to that. The user
1258-
# will be responsible for executing the pipeline after they've accumulated
1259-
# the commands they want to send.
1260-
db = pipeline
1251+
db = cls._get_db(pipeline, bulk=True)
12611252

12621253
for model in models:
12631254
# save() just returns the model, we don't need that here.
@@ -1272,25 +1263,25 @@ async def add(
12721263
return models
12731264

12741265
@classmethod
1275-
async def delete_all(
1266+
def _get_db(self, pipeline: Optional[Pipeline]=None, bulk: bool=False):
1267+
if pipeline is not None:
1268+
return pipeline
1269+
elif bulk:
1270+
return self.db().pipeline(transaction=False)
1271+
else:
1272+
return self.db()
1273+
1274+
@classmethod
1275+
async def delete_many(
12761276
cls,
12771277
models: Sequence["RedisModel"],
12781278
pipeline: Optional[Pipeline] = None,
1279-
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12801279
) -> int:
1281-
if pipeline is None:
1282-
db = cls.db().pipeline(transaction=False)
1283-
else:
1284-
db = pipeline
1285-
1286-
for model in models:
1287-
await model.delete(model.pk, pipeline=db)
1280+
db = cls._get_db(pipeline)
12881281

1289-
# If the user didn't give us a pipeline, then we need to execute
1290-
# the one we just created.
1291-
if pipeline is None:
1292-
result = await db.execute()
1293-
pipeline_verifier(result, expected_responses=len(models))
1282+
for chunk in ichunked(models, 100):
1283+
pks = [cls.make_primary_key(model.pk) for model in chunk]
1284+
await cls._delete(db, *pks)
12941285

12951286
return len(models)
12961287

@@ -1330,10 +1321,8 @@ def __init_subclass__(cls, **kwargs):
13301321

13311322
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
13321323
self.check()
1333-
if pipeline is None:
1334-
db = self.db()
1335-
else:
1336-
db = pipeline
1324+
db = self._get_db(pipeline)
1325+
13371326
document = jsonable_encoder(self.dict())
13381327
# TODO: Wrap any Redis response errors in a custom exception?
13391328
await db.hset(self.key(), mapping=document)
@@ -1502,10 +1491,8 @@ def __init__(self, *args, **kwargs):
15021491

15031492
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
15041493
self.check()
1505-
if pipeline is None:
1506-
db = self.db()
1507-
else:
1508-
db = pipeline
1494+
db = self._get_db(pipeline)
1495+
15091496
# TODO: Wrap response errors in a custom exception?
15101497
await db.execute_command("JSON.SET", self.key(), ".", self.json())
15111498
return self

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ python-ulid = "^1.0.3"
3535
cleo = "1.0.0a4"
3636
typing-extensions = "^4.0.0"
3737
hiredis = "^2.0.0"
38+
more-itertools = "^8.13.0"
3839

3940
[tool.poetry.dev-dependencies]
4041
mypy = "^0.950"

tests/test_hash_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ async def test_delete_many(m):
574574
members = [member1, member2]
575575
result = await m.Member.add(members)
576576
assert result == [member1, member2]
577-
result = await m.Member.delete_all(members)
577+
result = await m.Member.delete_many(members)
578578
assert result == 2
579579
with pytest.raises(NotFoundError):
580580
await m.Member.get(pk=member1.pk)

tests/test_json_model.py

+1-27
Original file line numberDiff line numberDiff line change
@@ -318,38 +318,12 @@ async def test_delete_many_implicit_pipeline(address, m):
318318
members = [member1, member2]
319319
result = await m.Member.add(members)
320320
assert result == [member1, member2]
321-
result = await m.Member.delete_all(members)
321+
result = await m.Member.delete_many(members)
322322
assert result == 2
323323
with pytest.raises(NotFoundError):
324324
await m.Member.get(pk=member2.pk)
325325

326326

327-
@py_test_mark_asyncio
328-
async def test_delete_many_explicit_transaction(address, m):
329-
member1 = m.Member(
330-
first_name="Andrew",
331-
last_name="Brookins",
332-
333-
join_date=today,
334-
address=address,
335-
age=38,
336-
)
337-
member2 = m.Member(
338-
first_name="Kim",
339-
last_name="Brookins",
340-
341-
join_date=today,
342-
address=address,
343-
age=34,
344-
)
345-
members = [member1, member2]
346-
result = await m.Member.add(members)
347-
assert result == [member1, member2]
348-
async with m.Member.db().pipeline(transaction=True) as pipeline:
349-
await m.Member.delete_all(members, pipeline=pipeline)
350-
assert await pipeline.execute() == [1, 1]
351-
352-
353327
async def save(members):
354328
for m in members:
355329
await m.save()

0 commit comments

Comments
 (0)