diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 73bbf2a2..03ccde85 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -47,6 +47,7 @@ model_registry = {} _T = TypeVar("_T") +Model = TypeVar("Model", bound="RedisModel") log = logging.getLogger(__name__) escaper = TokenEscaper() @@ -1160,7 +1161,7 @@ async def delete( return await cls._delete(db, cls.make_primary_key(pk)) @classmethod - async def get(cls, pk: Any) -> "RedisModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": raise NotImplementedError async def update(self, **field_values): @@ -1168,8 +1169,8 @@ async def update(self, **field_values): raise NotImplementedError async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "RedisModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": raise NotImplementedError async def expire( @@ -1266,11 +1267,11 @@ def get_annotations(cls): @classmethod async def add( - cls, - models: Sequence["RedisModel"], + cls: Type["Model"], + models: Sequence["Model"], pipeline: Optional[redis.client.Pipeline] = None, pipeline_verifier: Callable[..., Any] = verify_pipeline_response, - ) -> Sequence["RedisModel"]: + ) -> Sequence["Model"]: db = cls._get_db(pipeline, bulk=True) for model in models: @@ -1345,8 +1346,8 @@ def __init_subclass__(cls, **kwargs): ) async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "HashModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1368,7 +1369,7 @@ async def all_pks(cls): # type: ignore ) @classmethod - async def get(cls, pk: Any) -> "HashModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": document = await cls.db().hgetall(cls.make_primary_key(pk)) if not document: raise NotFoundError @@ -1513,8 +1514,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "JsonModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1559,7 +1560,7 @@ async def update(self, **field_values): await self.save() @classmethod - async def get(cls, pk: Any) -> "JsonModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": document = json.dumps(await cls.db().json().get(cls.make_key(pk))) if document == "null": raise NotFoundError