diff --git a/CHANGES b/CHANGES index bd96846b6d..031d909f23 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index bb2cccab52..f70a8d09ab 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Awaitable, Optional, Union from redis.exceptions import LockError, LockNotOwnedError +from redis.typing import Number if TYPE_CHECKING: from redis.asyncio import Redis, RedisCluster @@ -82,7 +83,7 @@ def __init__( timeout: Optional[float] = None, sleep: float = 0.1, blocking: bool = True, - blocking_timeout: Optional[float] = None, + blocking_timeout: Optional[Number] = None, thread_local: bool = True, ): """ @@ -167,7 +168,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): async def acquire( self, blocking: Optional[bool] = None, - blocking_timeout: Optional[float] = None, + blocking_timeout: Optional[Number] = None, token: Optional[Union[str, bytes]] = None, ): """ @@ -265,7 +266,7 @@ async def do_release(self, expected_token: bytes) -> None: raise LockNotOwnedError("Cannot release a lock that's no longer owned") def extend( - self, additional_time: float, replace_ttl: bool = False + self, additional_time: Number, replace_ttl: bool = False ) -> Awaitable[bool]: """ Adds more time to an already acquired lock. diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 033a8b7467..9973ef701f 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -174,11 +174,12 @@ async def test_extend_lock_replace_ttl(self, r): await lock.release() async def test_extend_lock_float(self, r): - lock = self.get_lock(r, "foo", timeout=10.0) + lock = self.get_lock(r, "foo", timeout=10.5) assert await lock.acquire(blocking=False) - assert 8000 < (await r.pttl("foo")) <= 10000 - assert await lock.extend(10.0) - assert 16000 < (await r.pttl("foo")) <= 20000 + assert 10400 < (await r.pttl("foo")) <= 10500 + old_ttl = await r.pttl("foo") + assert await lock.extend(10.5) + assert old_ttl + 10400 < (await r.pttl("foo")) <= old_ttl + 10500 await lock.release() async def test_extending_unlocked_lock_raises_error(self, r): diff --git a/tests/test_lock.py b/tests/test_lock.py index d77ff9717a..136c86e459 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -178,11 +178,12 @@ def test_extend_lock_replace_ttl(self, r): lock.release() def test_extend_lock_float(self, r): - lock = self.get_lock(r, "foo", timeout=10.0) + lock = self.get_lock(r, "foo", timeout=10.5) assert lock.acquire(blocking=False) - assert 8000 < r.pttl("foo") <= 10000 - assert lock.extend(10.0) - assert 16000 < r.pttl("foo") <= 20000 + assert 10400 < r.pttl("foo") <= 10500 + old_ttl = r.pttl("foo") + assert lock.extend(10.5) + assert old_ttl + 10400 < r.pttl("foo") <= old_ttl + 10500 lock.release() def test_extending_unlocked_lock_raises_error(self, r):