Skip to content

Commit c68adac

Browse files
marianhlavacchayim
andauthored
Use TypeVars for return types of RedisModel and its subtype's methods (#476)
Co-authored-by: Chayim <[email protected]>
1 parent 89b6c84 commit c68adac

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

aredis_om/model/model.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
model_registry = {}
4949
_T = TypeVar("_T")
50+
Model = TypeVar("Model", bound="RedisModel")
5051
log = logging.getLogger(__name__)
5152
escaper = TokenEscaper()
5253

@@ -1310,16 +1311,16 @@ async def delete(
13101311
return await cls._delete(db, cls.make_primary_key(pk))
13111312

13121313
@classmethod
1313-
async def get(cls, pk: Any) -> "RedisModel":
1314+
async def get(cls: Type["Model"], pk: Any) -> "Model":
13141315
raise NotImplementedError
13151316

13161317
async def update(self, **field_values):
13171318
"""Update this model instance with the specified key-value pairs."""
13181319
raise NotImplementedError
13191320

13201321
async def save(
1321-
self, pipeline: Optional[redis.client.Pipeline] = None
1322-
) -> "RedisModel":
1322+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1323+
) -> "Model":
13231324
raise NotImplementedError
13241325

13251326
async def expire(
@@ -1423,11 +1424,11 @@ def get_annotations(cls):
14231424

14241425
@classmethod
14251426
async def add(
1426-
cls,
1427-
models: Sequence["RedisModel"],
1427+
cls: Type["Model"],
1428+
models: Sequence["Model"],
14281429
pipeline: Optional[redis.client.Pipeline] = None,
14291430
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
1430-
) -> Sequence["RedisModel"]:
1431+
) -> Sequence["Model"]:
14311432
db = cls._get_db(pipeline, bulk=True)
14321433

14331434
for model in models:
@@ -1502,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
15021503
)
15031504

15041505
async def save(
1505-
self, pipeline: Optional[redis.client.Pipeline] = None
1506-
) -> "HashModel":
1506+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1507+
) -> "Model":
15071508
self.check()
15081509
db = self._get_db(pipeline)
15091510

@@ -1525,7 +1526,7 @@ async def all_pks(cls): # type: ignore
15251526
)
15261527

15271528
@classmethod
1528-
async def get(cls, pk: Any) -> "HashModel":
1529+
async def get(cls: Type["Model"], pk: Any) -> "Model":
15291530
document = await cls.db().hgetall(cls.make_primary_key(pk))
15301531
if not document:
15311532
raise NotFoundError
@@ -1676,8 +1677,8 @@ def __init__(self, *args, **kwargs):
16761677
super().__init__(*args, **kwargs)
16771678

16781679
async def save(
1679-
self, pipeline: Optional[redis.client.Pipeline] = None
1680-
) -> "JsonModel":
1680+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1681+
) -> "Model":
16811682
self.check()
16821683
db = self._get_db(pipeline)
16831684

@@ -1722,7 +1723,7 @@ async def update(self, **field_values):
17221723
await self.save()
17231724

17241725
@classmethod
1725-
async def get(cls, pk: Any) -> "JsonModel":
1726+
async def get(cls: Type["Model"], pk: Any) -> "Model":
17261727
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
17271728
if document == "null":
17281729
raise NotFoundError

0 commit comments

Comments
 (0)