Skip to content

Commit a00a68b

Browse files
dvora-hchayim
andauthored
Add delete_many to support for bulk deletes (#305)
* Add support for bulk deletes * linters * linters * fix review comments * update more-itertools version * poetry fix - maybe? * merge main & add more-itertools 8.14.0 * update poetry.lock * linters * fix test Co-authored-by: Chayim I. Kirshen <[email protected]>
1 parent 4661459 commit a00a68b

File tree

5 files changed

+171
-395
lines changed

5 files changed

+171
-395
lines changed

aredis_om/model/model.py

+42-24
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
no_type_check,
2525
)
2626

27+
from more_itertools import ichunked
2728
from pydantic import BaseModel, validator
2829
from pydantic.fields import FieldInfo as PydanticFieldInfo
2930
from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -1117,9 +1118,17 @@ def key(self):
11171118
return self.make_primary_key(pk)
11181119

11191120
@classmethod
1120-
async def delete(cls, pk: Any) -> int:
1121+
async def _delete(cls, db, *pks):
1122+
return await db.delete(*pks)
1123+
1124+
@classmethod
1125+
async def delete(
1126+
cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None
1127+
) -> int:
11211128
"""Delete data at this key."""
1122-
return await cls.db().delete(cls.make_primary_key(pk))
1129+
db = cls._get_db(pipeline)
1130+
1131+
return await cls._delete(db, cls.make_primary_key(pk))
11231132

11241133
@classmethod
11251134
async def get(cls, pk: Any) -> "RedisModel":
@@ -1137,10 +1146,7 @@ async def save(
11371146
async def expire(
11381147
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
11391148
):
1140-
if pipeline is None:
1141-
db = self.db()
1142-
else:
1143-
db = pipeline
1149+
db = self._get_db(pipeline)
11441150

11451151
# TODO: Wrap any Redis response errors in a custom exception?
11461152
await db.expire(self.make_primary_key(self.pk), num_seconds)
@@ -1232,16 +1238,7 @@ async def add(
12321238
pipeline: Optional[redis.client.Pipeline] = None,
12331239
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
12341240
) -> Sequence["RedisModel"]:
1235-
if pipeline is None:
1236-
# By default, send commands in a pipeline. Saving each model will
1237-
# be atomic, but Redis may process other commands in between
1238-
# these saves.
1239-
db = cls.db().pipeline(transaction=False)
1240-
else:
1241-
# If the user gave us a pipeline, add our commands to that. The user
1242-
# will be responsible for executing the pipeline after they've accumulated
1243-
# the commands they want to send.
1244-
db = pipeline
1241+
db = cls._get_db(pipeline, bulk=True)
12451242

12461243
for model in models:
12471244
# save() just returns the model, we don't need that here.
@@ -1255,6 +1252,31 @@ async def add(
12551252

12561253
return models
12571254

1255+
@classmethod
1256+
def _get_db(
1257+
self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False
1258+
):
1259+
if pipeline is not None:
1260+
return pipeline
1261+
elif bulk:
1262+
return self.db().pipeline(transaction=False)
1263+
else:
1264+
return self.db()
1265+
1266+
@classmethod
1267+
async def delete_many(
1268+
cls,
1269+
models: Sequence["RedisModel"],
1270+
pipeline: Optional[redis.client.Pipeline] = None,
1271+
) -> int:
1272+
db = cls._get_db(pipeline)
1273+
1274+
for chunk in ichunked(models, 100):
1275+
pks = [cls.make_primary_key(model.pk) for model in chunk]
1276+
await cls._delete(db, *pks)
1277+
1278+
return len(models)
1279+
12581280
@classmethod
12591281
def redisearch_schema(cls):
12601282
raise NotImplementedError
@@ -1293,10 +1315,8 @@ async def save(
12931315
self, pipeline: Optional[redis.client.Pipeline] = None
12941316
) -> "HashModel":
12951317
self.check()
1296-
if pipeline is None:
1297-
db = self.db()
1298-
else:
1299-
db = pipeline
1318+
db = self._get_db(pipeline)
1319+
13001320
document = jsonable_encoder(self.dict())
13011321
# TODO: Wrap any Redis response errors in a custom exception?
13021322
await db.hset(self.key(), mapping=document)
@@ -1467,10 +1487,8 @@ async def save(
14671487
self, pipeline: Optional[redis.client.Pipeline] = None
14681488
) -> "JsonModel":
14691489
self.check()
1470-
if pipeline is None:
1471-
db = self.db()
1472-
else:
1473-
db = pipeline
1490+
db = self._get_db(pipeline)
1491+
14741492
# TODO: Wrap response errors in a custom exception?
14751493
await db.execute_command("JSON.SET", self.key(), ".", self.json())
14761494
return self

0 commit comments

Comments
 (0)