Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 41 additions & 42 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def __init__(
self._pool_lock = pool_lock
self._cache = cache
self._cache_lock = threading.RLock()
self._current_command_cache_key = None
self._current_command_cache_entry = None
self._current_options = None
self.register_connect_callback(self._enable_tracking_callback)

Expand Down Expand Up @@ -1453,42 +1453,49 @@ def send_command(self, *args, **kwargs):
if not self._cache.is_cachable(
CacheKey(command=args[0], redis_keys=(), redis_args=())
):
self._current_command_cache_key = None
self._current_command_cache_entry = None
self._conn.send_command(*args, **kwargs)
return

if kwargs.get("keys") is None:
raise ValueError("Cannot create cache key.")

# Creates cache key.
self._current_command_cache_key = CacheKey(
cache_key = CacheKey(
command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
)
self._current_command_cache_entry = None

with self._cache_lock:
# We have to trigger invalidation processing in case if
# it was cached by another connection to avoid
# queueing invalidations in stale connections.
if self._cache.get(self._current_command_cache_key):
entry = self._cache.get(self._current_command_cache_key)

if entry.connection_ref != self._conn:
cache_entry = self._cache.get(cache_key)
if cache_entry is not None and cache_entry.status == CacheEntryStatus.VALID:
# We have to trigger invalidation processing in case if
# it was cached by another connection to avoid
# queueing invalidations in stale connections.
if cache_entry.connection_ref != self._conn:
with self._pool_lock:
while entry.connection_ref.can_read():
entry.connection_ref.read_response(push_request=True)

return
while cache_entry.connection_ref.can_read():
cache_entry.connection_ref.read_response(push_request=True)
# Check if entry still exists, if so it must be cache_entry we got because of cache_lock
if self._cache.get(cache_key) is not None:
self._current_command_cache_entry = cache_entry
return
cache_entry = None
else:
self._current_command_cache_entry = cache_entry
return

# Set temporary entry value to prevent
# race condition from another connection.
self._cache.set(
CacheEntry(
cache_key=self._current_command_cache_key,
if cache_entry is None:
# Creates cache entry.
cache_entry = CacheEntry(
cache_key=cache_key,
cache_value=self.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=self._conn,
)
)
# Set temporary entry value to prevent
# race condition from another connection.
self._cache.set(cache_entry)
self._current_command_cache_entry = cache_entry

# Send command over socket only if it's allowed
# read-only command that not yet cached.
Expand All @@ -1501,17 +1508,13 @@ def read_response(
self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
):
with self._cache_lock:
# Check if command response exists in a cache and it's not in progress.
# Check if command response cache entry exists and it's valid.
if (
self._current_command_cache_key is not None
and self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
self._current_command_cache_entry is not None
and self._current_command_cache_entry.status == CacheEntryStatus.VALID
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
)
self._current_command_cache_key = None
res = copy.deepcopy(self._current_command_cache_entry.cache_value)
self._current_command_cache_entry = None
return res

response = self._conn.read_response(
Expand All @@ -1522,23 +1525,19 @@ def read_response(

with self._cache_lock:
# Prevent not-allowed command from caching.
if self._current_command_cache_key is None:
if self._current_command_cache_entry is None:
return response
# If response is None prevent from caching.
cache_key = self._current_command_cache_entry.cache_key
if response is None:
self._cache.delete_by_cache_keys([self._current_command_cache_key])
self._cache.delete_by_cache_keys([cache_key])
self._current_command_cache_entry = None
return response

cache_entry = self._cache.get(self._current_command_cache_key)

# Cache only responses that still valid
# and wasn't invalidated by another connection in meantime.
if cache_entry is not None:
cache_entry.status = CacheEntryStatus.VALID
cache_entry.cache_value = response
self._cache.set(cache_entry)

self._current_command_cache_key = None
# No bother entry exists in cache or not, this ensures we don't overwrite another entry.
self._current_command_cache_entry.status = CacheEntryStatus.VALID
self._current_command_cache_entry.cache_value = response
self._current_command_cache_entry = None

return response

Expand Down
141 changes: 114 additions & 27 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,6 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
mock_cache.is_cachable.return_value = True
mock_cache.get.side_effect = [
None,
None,
CacheEntry(
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=mock_connection,
),
CacheEntry(
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
Expand Down Expand Up @@ -526,23 +517,15 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
)
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
assert proxy_connection.read_response() == b"bar"
assert proxy_connection._current_command_cache_key is None
assert proxy_connection._current_command_cache_entry is None

# cached reply
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
assert proxy_connection.read_response() == b"bar"
assert proxy_connection._current_command_cache_entry is None

mock_cache.set.assert_has_calls(
[
call(
CacheEntry(
cache_key=CacheKey(
command="GET",
redis_keys=("foo",),
redis_args=("GET", "foo"),
),
cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=mock_connection,
)
),
call(
CacheEntry(
cache_key=CacheKey(
Expand Down Expand Up @@ -570,11 +553,6 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
)
),
call(
CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
)
),
]
)

Expand Down Expand Up @@ -613,3 +591,112 @@ def test_triggers_invalidation_processing_on_another_connection(
assert proxy_connection.read_response() == b"bar"
assert another_conn.can_read.call_count == 2
another_conn.read_response.assert_called_once()

@pytest.mark.skipif(
platform.python_implementation() == "PyPy",
reason="Pypy doesn't support side_effect",
)
def test_cache_entry_in_progress(self, mock_cache, mock_connection):
mock_connection.retry = "mock"
mock_connection.host = "mock"
mock_connection.port = "mock"
mock_connection.credential_provider = UsernamePasswordCredentialProvider()

another_conn = copy.deepcopy(mock_connection)
another_conn.can_read.return_value = False
cache_entry = CacheEntry(
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=another_conn,
)
mock_cache.is_cachable.return_value = True
mock_cache.get.return_value = cache_entry
mock_connection.can_read.return_value = False
mock_connection.read_response.return_value = b"bar2"

proxy_connection = CacheProxyConnection(
mock_connection, mock_cache, threading.RLock()
)
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

assert proxy_connection.read_response() == b"bar2"
mock_connection.send_command.assert_called_once()
mock_connection.read_response.assert_called_once()

@pytest.mark.skipif(
platform.python_implementation() == "PyPy",
reason="Pypy doesn't support side_effect",
)
def test_cache_entry_gone_between_send_and_read(self, mock_cache, mock_connection):
mock_connection.retry = "mock"
mock_connection.host = "mock"
mock_connection.port = "mock"
mock_connection.credential_provider = UsernamePasswordCredentialProvider()

another_conn = copy.deepcopy(mock_connection)
another_conn.can_read.return_value = False
cache_entry = CacheEntry(
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=another_conn,
)
mock_cache.is_cachable.return_value = True
mock_cache.get.return_value = cache_entry
mock_connection.can_read.return_value = False
mock_connection.read_response.return_value = None

proxy_connection = CacheProxyConnection(
mock_connection, mock_cache, threading.RLock()
)
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

# cache entry gone
mock_cache.get.return_value = None

assert proxy_connection.read_response() == b"bar"
mock_connection.send_command.assert_not_called()
mock_connection.read_response.assert_not_called()

@pytest.mark.skipif(
platform.python_implementation() == "PyPy",
reason="Pypy doesn't support side_effect",
)
def test_cache_entry_fill_between_send_and_read(self, mock_cache, mock_connection):
mock_connection.retry = "mock"
mock_connection.host = "mock"
mock_connection.port = "mock"
mock_connection.credential_provider = UsernamePasswordCredentialProvider()

another_conn = copy.deepcopy(mock_connection)
another_conn.can_read.return_value = False

mock_cache.is_cachable.return_value = True
mock_cache.get.return_value = None
mock_connection.can_read.return_value = False
mock_connection.read_response.return_value = b"bar2"

proxy_connection = CacheProxyConnection(
mock_connection, mock_cache, threading.RLock()
)
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

cache_entry = CacheEntry(
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=another_conn,
)
# cache entry fill
mock_cache.get.return_value = cache_entry

assert proxy_connection.read_response() == b"bar2"
mock_connection.send_command.assert_called_once()
mock_connection.read_response.assert_called_once()