34
34
from pydantic .utils import Representation
35
35
from typing_extensions import Protocol , get_args , get_origin
36
36
from ulid import ULID
37
+ from more_itertools import ichunked
37
38
38
39
from ..checks import has_redis_json , has_redisearch
39
40
from ..connections import get_redis_connection
@@ -1114,14 +1115,16 @@ def key(self):
1114
1115
pk = getattr (self , self ._meta .primary_key .field .name )
1115
1116
return self .make_primary_key (pk )
1116
1117
1118
+ @classmethod
1119
+ async def _delete (cls , db , * pks ):
1120
+ return await db .delete (* pks )
1121
+
1117
1122
@classmethod
1118
1123
async def delete (cls , pk : Any , pipeline : Optional [Pipeline ] = None ) -> int :
1119
1124
"""Delete data at this key."""
1120
- if pipeline is None :
1121
- db = cls .db ()
1122
- else :
1123
- db = pipeline
1124
- return await db .delete (cls .make_primary_key (pk ))
1125
+ db = cls ._get_db (pipeline )
1126
+
1127
+ return await cls ._delete (db , cls .make_primary_key (pk ))
1125
1128
1126
1129
@classmethod
1127
1130
async def get (cls , pk : Any ) -> "RedisModel" :
@@ -1135,10 +1138,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
1135
1138
raise NotImplementedError
1136
1139
1137
1140
async def expire (self , num_seconds : int , pipeline : Optional [Pipeline ] = None ):
1138
- if pipeline is None :
1139
- db = self .db ()
1140
- else :
1141
- db = pipeline
1141
+ db = self ._get_db (pipeline )
1142
1142
1143
1143
# TODO: Wrap any Redis response errors in a custom exception?
1144
1144
await db .expire (self .make_primary_key (self .pk ), num_seconds )
@@ -1248,16 +1248,7 @@ async def add(
1248
1248
pipeline : Optional [Pipeline ] = None ,
1249
1249
pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1250
1250
) -> Sequence ["RedisModel" ]:
1251
- if pipeline is None :
1252
- # By default, send commands in a pipeline. Saving each model will
1253
- # be atomic, but Redis may process other commands in between
1254
- # these saves.
1255
- db = cls .db ().pipeline (transaction = False )
1256
- else :
1257
- # If the user gave us a pipeline, add our commands to that. The user
1258
- # will be responsible for executing the pipeline after they've accumulated
1259
- # the commands they want to send.
1260
- db = pipeline
1251
+ db = cls ._get_db (pipeline , bulk = True )
1261
1252
1262
1253
for model in models :
1263
1254
# save() just returns the model, we don't need that here.
@@ -1272,25 +1263,25 @@ async def add(
1272
1263
return models
1273
1264
1274
1265
@classmethod
1275
- async def delete_all (
1266
+ def _get_db (self , pipeline : Optional [Pipeline ]= None , bulk : bool = False ):
1267
+ if pipeline is not None :
1268
+ return pipeline
1269
+ elif bulk :
1270
+ return self .db ().pipeline (transaction = False )
1271
+ else :
1272
+ return self .db ()
1273
+
1274
+ @classmethod
1275
+ async def delete_many (
1276
1276
cls ,
1277
1277
models : Sequence ["RedisModel" ],
1278
1278
pipeline : Optional [Pipeline ] = None ,
1279
- pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1280
1279
) -> int :
1281
- if pipeline is None :
1282
- db = cls .db ().pipeline (transaction = False )
1283
- else :
1284
- db = pipeline
1285
-
1286
- for model in models :
1287
- await model .delete (model .pk , pipeline = db )
1280
+ db = cls ._get_db (pipeline )
1288
1281
1289
- # If the user didn't give us a pipeline, then we need to execute
1290
- # the one we just created.
1291
- if pipeline is None :
1292
- result = await db .execute ()
1293
- pipeline_verifier (result , expected_responses = len (models ))
1282
+ for chunk in ichunked (models , 100 ):
1283
+ pks = [cls .make_primary_key (model .pk ) for model in chunk ]
1284
+ await cls ._delete (db , * pks )
1294
1285
1295
1286
return len (models )
1296
1287
@@ -1330,10 +1321,8 @@ def __init_subclass__(cls, **kwargs):
1330
1321
1331
1322
async def save (self , pipeline : Optional [Pipeline ] = None ) -> "HashModel" :
1332
1323
self .check ()
1333
- if pipeline is None :
1334
- db = self .db ()
1335
- else :
1336
- db = pipeline
1324
+ db = self ._get_db (pipeline )
1325
+
1337
1326
document = jsonable_encoder (self .dict ())
1338
1327
# TODO: Wrap any Redis response errors in a custom exception?
1339
1328
await db .hset (self .key (), mapping = document )
@@ -1502,10 +1491,8 @@ def __init__(self, *args, **kwargs):
1502
1491
1503
1492
async def save (self , pipeline : Optional [Pipeline ] = None ) -> "JsonModel" :
1504
1493
self .check ()
1505
- if pipeline is None :
1506
- db = self .db ()
1507
- else :
1508
- db = pipeline
1494
+ db = self ._get_db (pipeline )
1495
+
1509
1496
# TODO: Wrap response errors in a custom exception?
1510
1497
await db .execute_command ("JSON.SET" , self .key (), "." , self .json ())
1511
1498
return self
0 commit comments