Skip to content

Commit d73072a

Browse files
committed
Use TypeVars for return types of RedisModel and its subtype's methods
1 parent 4ee61cb commit d73072a

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

aredis_om/model/model.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
model_registry = {}
4949
_T = TypeVar("_T")
50+
Model = TypeVar("Model", bound="RedisModel")
5051
log = logging.getLogger(__name__)
5152
escaper = TokenEscaper()
5253

@@ -1160,16 +1161,16 @@ async def delete(
11601161
return await cls._delete(db, cls.make_primary_key(pk))
11611162

11621163
@classmethod
1163-
async def get(cls, pk: Any) -> "RedisModel":
1164+
async def get(cls: Type["Model"], pk: Any) -> "Model":
11641165
raise NotImplementedError
11651166

11661167
async def update(self, **field_values):
11671168
"""Update this model instance with the specified key-value pairs."""
11681169
raise NotImplementedError
11691170

11701171
async def save(
1171-
self, pipeline: Optional[redis.client.Pipeline] = None
1172-
) -> "RedisModel":
1172+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1173+
) -> "Model":
11731174
raise NotImplementedError
11741175

11751176
async def expire(
@@ -1266,11 +1267,11 @@ def get_annotations(cls):
12661267

12671268
@classmethod
12681269
async def add(
1269-
cls,
1270-
models: Sequence["RedisModel"],
1270+
cls: Type["Model"],
1271+
models: Sequence["Model"],
12711272
pipeline: Optional[redis.client.Pipeline] = None,
12721273
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
1273-
) -> Sequence["RedisModel"]:
1274+
) -> Sequence["Model"]:
12741275
db = cls._get_db(pipeline, bulk=True)
12751276

12761277
for model in models:
@@ -1345,8 +1346,8 @@ def __init_subclass__(cls, **kwargs):
13451346
)
13461347

13471348
async def save(
1348-
self, pipeline: Optional[redis.client.Pipeline] = None
1349-
) -> "HashModel":
1349+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1350+
) -> "Model":
13501351
self.check()
13511352
db = self._get_db(pipeline)
13521353

@@ -1368,7 +1369,7 @@ async def all_pks(cls): # type: ignore
13681369
)
13691370

13701371
@classmethod
1371-
async def get(cls, pk: Any) -> "HashModel":
1372+
async def get(cls: Type["Model"], pk: Any) -> "Model":
13721373
document = await cls.db().hgetall(cls.make_primary_key(pk))
13731374
if not document:
13741375
raise NotFoundError
@@ -1513,8 +1514,8 @@ def __init__(self, *args, **kwargs):
15131514
super().__init__(*args, **kwargs)
15141515

15151516
async def save(
1516-
self, pipeline: Optional[redis.client.Pipeline] = None
1517-
) -> "JsonModel":
1517+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1518+
) -> "Model":
15181519
self.check()
15191520
db = self._get_db(pipeline)
15201521

@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
15591560
await self.save()
15601561

15611562
@classmethod
1562-
async def get(cls, pk: Any) -> "JsonModel":
1563+
async def get(cls: Type["Model"], pk: Any) -> "Model":
15631564
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
15641565
if document == "null":
15651566
raise NotFoundError

0 commit comments

Comments
 (0)