47
47
48
48
model_registry = {}
49
49
_T = TypeVar ("_T" )
50
+ Model = TypeVar ("Model" , bound = "RedisModel" )
50
51
log = logging .getLogger (__name__ )
51
52
escaper = TokenEscaper ()
52
53
@@ -1310,16 +1311,16 @@ async def delete(
1310
1311
return await cls ._delete (db , cls .make_primary_key (pk ))
1311
1312
1312
1313
@classmethod
1313
- async def get (cls , pk : Any ) -> "RedisModel " :
1314
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1314
1315
raise NotImplementedError
1315
1316
1316
1317
async def update (self , ** field_values ):
1317
1318
"""Update this model instance with the specified key-value pairs."""
1318
1319
raise NotImplementedError
1319
1320
1320
1321
async def save (
1321
- self , pipeline : Optional [redis .client .Pipeline ] = None
1322
- ) -> "RedisModel " :
1322
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1323
+ ) -> "Model " :
1323
1324
raise NotImplementedError
1324
1325
1325
1326
async def expire (
@@ -1423,11 +1424,11 @@ def get_annotations(cls):
1423
1424
1424
1425
@classmethod
1425
1426
async def add (
1426
- cls ,
1427
- models : Sequence ["RedisModel " ],
1427
+ cls : Type [ "Model" ] ,
1428
+ models : Sequence ["Model " ],
1428
1429
pipeline : Optional [redis .client .Pipeline ] = None ,
1429
1430
pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1430
- ) -> Sequence ["RedisModel " ]:
1431
+ ) -> Sequence ["Model " ]:
1431
1432
db = cls ._get_db (pipeline , bulk = True )
1432
1433
1433
1434
for model in models :
@@ -1502,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
1502
1503
)
1503
1504
1504
1505
async def save (
1505
- self , pipeline : Optional [redis .client .Pipeline ] = None
1506
- ) -> "HashModel " :
1506
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1507
+ ) -> "Model " :
1507
1508
self .check ()
1508
1509
db = self ._get_db (pipeline )
1509
1510
@@ -1525,7 +1526,7 @@ async def all_pks(cls): # type: ignore
1525
1526
)
1526
1527
1527
1528
@classmethod
1528
- async def get (cls , pk : Any ) -> "HashModel " :
1529
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1529
1530
document = await cls .db ().hgetall (cls .make_primary_key (pk ))
1530
1531
if not document :
1531
1532
raise NotFoundError
@@ -1676,8 +1677,8 @@ def __init__(self, *args, **kwargs):
1676
1677
super ().__init__ (* args , ** kwargs )
1677
1678
1678
1679
async def save (
1679
- self , pipeline : Optional [redis .client .Pipeline ] = None
1680
- ) -> "JsonModel " :
1680
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1681
+ ) -> "Model " :
1681
1682
self .check ()
1682
1683
db = self ._get_db (pipeline )
1683
1684
@@ -1722,7 +1723,7 @@ async def update(self, **field_values):
1722
1723
await self .save ()
1723
1724
1724
1725
@classmethod
1725
- async def get (cls , pk : Any ) -> "JsonModel " :
1726
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1726
1727
document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
1727
1728
if document == "null" :
1728
1729
raise NotFoundError
0 commit comments