Skip to content

Commit 03092ba

Browse files
SentinelManagedConnection searches for new master upon connection failure (redis#3560)
1 parent 4e59d24 commit 03092ba

File tree

4 files changed

+130
-14
lines changed

4 files changed

+130
-14
lines changed

redis/asyncio/sentinel.py

+50-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2+
import inspect
23
import random
4+
import socket
35
import weakref
46
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
57

@@ -11,8 +13,13 @@
1113
SSLConnection,
1214
)
1315
from redis.commands import AsyncSentinelCommands
14-
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
15-
from redis.utils import str_if_bytes
16+
from redis.exceptions import (
17+
ConnectionError,
18+
ReadOnlyError,
19+
RedisError,
20+
ResponseError,
21+
TimeoutError,
22+
)
1623

1724

1825
class MasterNotFoundError(ConnectionError):
@@ -37,11 +44,47 @@ def __repr__(self):
3744

3845
async def connect_to(self, address):
3946
self.host, self.port = address
40-
await super().connect()
41-
if self.connection_pool.check_connection:
42-
await self.send_command("PING")
43-
if str_if_bytes(await self.read_response()) != "PONG":
44-
raise ConnectionError("PING failed")
47+
48+
if self.is_connected:
49+
return
50+
try:
51+
await self._connect()
52+
except asyncio.CancelledError:
53+
raise # in 3.7 and earlier, this is an Exception, not BaseException
54+
except (socket.timeout, asyncio.TimeoutError):
55+
raise TimeoutError("Timeout connecting to server")
56+
except OSError as e:
57+
raise ConnectionError(self._error_message(e))
58+
except Exception as exc:
59+
raise ConnectionError(exc) from exc
60+
61+
try:
62+
if not self.redis_connect_func:
63+
# Use the default on_connect function
64+
await self.on_connect_check_health(
65+
check_health=self.connection_pool.check_connection
66+
)
67+
else:
68+
# Use the passed function redis_connect_func
69+
(
70+
await self.redis_connect_func(self)
71+
if asyncio.iscoroutinefunction(self.redis_connect_func)
72+
else self.redis_connect_func(self)
73+
)
74+
except RedisError:
75+
# clean up after any error in on_connect
76+
await self.disconnect()
77+
raise
78+
79+
# run any user callbacks. right now the only internal callback
80+
# is for pubsub channel/pattern resubscription
81+
# first, remove any dead weakrefs
82+
self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
83+
for ref in self._connect_callbacks:
84+
callback = ref()
85+
task = callback(self)
86+
if task and inspect.isawaitable(task):
87+
await task
4588

4689
async def _connect_retry(self):
4790
if self._reader:

redis/sentinel.py

+45-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import random
2+
import socket
23
import weakref
34
from typing import Optional
45

56
from redis.client import Redis
67
from redis.commands import SentinelCommands
78
from redis.connection import Connection, ConnectionPool, SSLConnection
8-
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
9-
from redis.utils import str_if_bytes
9+
from redis.exceptions import (
10+
ConnectionError,
11+
ReadOnlyError,
12+
RedisError,
13+
ResponseError,
14+
TimeoutError,
15+
)
1016

1117

1218
class MasterNotFoundError(ConnectionError):
@@ -35,11 +41,39 @@ def __repr__(self):
3541

3642
def connect_to(self, address):
3743
self.host, self.port = address
38-
super().connect()
39-
if self.connection_pool.check_connection:
40-
self.send_command("PING")
41-
if str_if_bytes(self.read_response()) != "PONG":
42-
raise ConnectionError("PING failed")
44+
45+
if self._sock:
46+
return
47+
try:
48+
sock = self._connect()
49+
except socket.timeout:
50+
raise TimeoutError("Timeout connecting to server")
51+
except OSError as e:
52+
raise ConnectionError(self._error_message(e))
53+
54+
self._sock = sock
55+
try:
56+
if self.redis_connect_func is None:
57+
# Use the default on_connect function
58+
self.on_connect_check_health(
59+
check_health=self.connection_pool.check_connection
60+
)
61+
else:
62+
# Use the passed function redis_connect_func
63+
self.redis_connect_func(self)
64+
except RedisError:
65+
# clean up after any error in on_connect
66+
self.disconnect()
67+
raise
68+
69+
# run any user callbacks. right now the only internal callback
70+
# is for pubsub channel/pattern resubscription
71+
# first, remove any dead weakrefs
72+
self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
73+
for ref in self._connect_callbacks:
74+
callback = ref()
75+
if callback:
76+
callback(self)
4377

4478
def _connect_retry(self):
4579
if self._sock:
@@ -294,13 +328,16 @@ def discover_master(self, service_name):
294328
"""
295329
collected_errors = list()
296330
for sentinel_no, sentinel in enumerate(self.sentinels):
331+
# print(f"Sentinel: {sentinel_no}")
297332
try:
298333
masters = sentinel.sentinel_masters()
299334
except (ConnectionError, TimeoutError) as e:
300335
collected_errors.append(f"{sentinel} - {e!r}")
301336
continue
302337
state = masters.get(service_name)
338+
# print(f"Found master: {state}")
303339
if state and self.check_master_state(state, service_name):
340+
# print("Valid state")
304341
# Put this sentinel at the top of the list
305342
self.sentinels[0], self.sentinels[sentinel_no] = (
306343
sentinel,
@@ -313,6 +350,7 @@ def discover_master(self, service_name):
313350
else state["ip"]
314351
)
315352
return ip, state["port"]
353+
# print("Ignoring it")
316354

317355
error_info = ""
318356
if len(collected_errors) > 0:

tests/test_asyncio/test_sentinel_managed_connection.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ async def mock_connect():
3434
conn._connect.side_effect = mock_connect
3535
await conn.connect()
3636
assert conn._connect.call_count == 3
37+
assert connection_pool.get_master_address.call_count == 3
3738
await conn.disconnect()
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import socket
2+
3+
from redis.retry import Retry
4+
from redis.sentinel import SentinelManagedConnection
5+
from redis.backoff import NoBackoff
6+
from unittest import mock
7+
8+
9+
def test_connect_retry_on_timeout_error(master_host):
10+
"""Test that the _connect function is retried in case of a timeout"""
11+
connection_pool = mock.Mock()
12+
connection_pool.get_master_address = mock.Mock(
13+
return_value=(master_host[0], master_host[1])
14+
)
15+
conn = SentinelManagedConnection(
16+
retry_on_timeout=True,
17+
retry=Retry(NoBackoff(), 3),
18+
connection_pool=connection_pool,
19+
)
20+
origin_connect = conn._connect
21+
conn._connect = mock.Mock()
22+
23+
def mock_connect():
24+
# connect only on the last retry
25+
if conn._connect.call_count <= 2:
26+
raise socket.timeout
27+
else:
28+
return origin_connect()
29+
30+
conn._connect.side_effect = mock_connect
31+
conn.connect()
32+
assert conn._connect.call_count == 3
33+
assert connection_pool.get_master_address.call_count == 3
34+
conn.disconnect()

0 commit comments

Comments
 (0)