47
47
48
48
model_registry = {}
49
49
_T = TypeVar ("_T" )
50
+ Model = TypeVar ("Model" , bound = "RedisModel" )
50
51
log = logging .getLogger (__name__ )
51
52
escaper = TokenEscaper ()
52
53
@@ -1160,16 +1161,16 @@ async def delete(
1160
1161
return await cls ._delete (db , cls .make_primary_key (pk ))
1161
1162
1162
1163
@classmethod
1163
- async def get (cls , pk : Any ) -> "RedisModel " :
1164
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1164
1165
raise NotImplementedError
1165
1166
1166
1167
async def update (self , ** field_values ):
1167
1168
"""Update this model instance with the specified key-value pairs."""
1168
1169
raise NotImplementedError
1169
1170
1170
1171
async def save (
1171
- self , pipeline : Optional [redis .client .Pipeline ] = None
1172
- ) -> "RedisModel " :
1172
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1173
+ ) -> "Model " :
1173
1174
raise NotImplementedError
1174
1175
1175
1176
async def expire (
@@ -1266,11 +1267,11 @@ def get_annotations(cls):
1266
1267
1267
1268
@classmethod
1268
1269
async def add (
1269
- cls ,
1270
- models : Sequence ["RedisModel " ],
1270
+ cls : Type [ "Model" ] ,
1271
+ models : Sequence ["Model " ],
1271
1272
pipeline : Optional [redis .client .Pipeline ] = None ,
1272
1273
pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1273
- ) -> Sequence ["RedisModel " ]:
1274
+ ) -> Sequence ["Model " ]:
1274
1275
db = cls ._get_db (pipeline , bulk = True )
1275
1276
1276
1277
for model in models :
@@ -1345,8 +1346,8 @@ def __init_subclass__(cls, **kwargs):
1345
1346
)
1346
1347
1347
1348
async def save (
1348
- self , pipeline : Optional [redis .client .Pipeline ] = None
1349
- ) -> "HashModel " :
1349
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1350
+ ) -> "Model " :
1350
1351
self .check ()
1351
1352
db = self ._get_db (pipeline )
1352
1353
@@ -1368,7 +1369,7 @@ async def all_pks(cls): # type: ignore
1368
1369
)
1369
1370
1370
1371
@classmethod
1371
- async def get (cls , pk : Any ) -> "HashModel " :
1372
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1372
1373
document = await cls .db ().hgetall (cls .make_primary_key (pk ))
1373
1374
if not document :
1374
1375
raise NotFoundError
@@ -1513,8 +1514,8 @@ def __init__(self, *args, **kwargs):
1513
1514
super ().__init__ (* args , ** kwargs )
1514
1515
1515
1516
async def save (
1516
- self , pipeline : Optional [redis .client .Pipeline ] = None
1517
- ) -> "JsonModel " :
1517
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1518
+ ) -> "Model " :
1518
1519
self .check ()
1519
1520
db = self ._get_db (pipeline )
1520
1521
@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
1559
1560
await self .save ()
1560
1561
1561
1562
@classmethod
1562
- async def get (cls , pk : Any ) -> "JsonModel " :
1563
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1563
1564
document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
1564
1565
if document == "null" :
1565
1566
raise NotFoundError
0 commit comments