Skip to content

Commit e587082

Browse files
Add Redis readiness verification (redis#3555)
1 parent 04589d4 commit e587082

11 files changed

+366
-49
lines changed

redis/asyncio/client.py

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __init__(
224224
encoding: str = "utf-8",
225225
encoding_errors: str = "strict",
226226
decode_responses: bool = False,
227+
check_ready: bool = False,
227228
retry_on_timeout: bool = False,
228229
retry_on_error: Optional[list] = None,
229230
ssl: bool = False,
@@ -291,6 +292,7 @@ def __init__(
291292
"encoding": encoding,
292293
"encoding_errors": encoding_errors,
293294
"decode_responses": decode_responses,
295+
"check_ready": check_ready,
294296
"retry_on_timeout": retry_on_timeout,
295297
"retry_on_error": retry_on_error,
296298
"retry": copy.deepcopy(retry),

redis/asyncio/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def __init__(
258258
encoding_errors: str = "strict",
259259
decode_responses: bool = False,
260260
# Connection related kwargs
261+
check_ready: bool = False,
261262
health_check_interval: float = 0,
262263
socket_connect_timeout: Optional[float] = None,
263264
socket_keepalive: bool = False,
@@ -313,6 +314,7 @@ def __init__(
313314
"encoding_errors": encoding_errors,
314315
"decode_responses": decode_responses,
315316
# Connection related kwargs
317+
"check_ready": check_ready,
316318
"health_check_interval": health_check_interval,
317319
"socket_connect_timeout": socket_connect_timeout,
318320
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
encoding_errors: str = "strict",
149149
decode_responses: bool = False,
150150
parser_class: Type[BaseParser] = DefaultParser,
151+
check_ready: bool = False,
151152
socket_read_size: int = 65536,
152153
health_check_interval: float = 0,
153154
client_name: Optional[str] = None,
@@ -204,6 +205,7 @@ def __init__(
204205
self.health_check_interval = health_check_interval
205206
self.next_health_check: float = -1
206207
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208+
self.check_ready = check_ready
207209
self.redis_connect_func = redis_connect_func
208210
self._reader: Optional[asyncio.StreamReader] = None
209211
self._writer: Optional[asyncio.StreamWriter] = None
@@ -295,14 +297,48 @@ async def connect(self):
295297
"""Connects to the Redis server if not already connected"""
296298
await self.connect_check_health(check_health=True)
297299

300+
async def _connect_check_ready(self):
301+
await self._connect()
302+
303+
# Doing handshake since connect and send operations work even when Redis is not ready
304+
if self.check_ready:
305+
try:
306+
ping_cmd = self.pack_command("PING")
307+
if self.socket_timeout:
308+
await asyncio.wait_for(
309+
self._send_packed_command(ping_cmd), self.socket_timeout
310+
)
311+
else:
312+
await self._send_packed_command(ping_cmd)
313+
314+
if self.socket_timeout is not None:
315+
async with async_timeout(self.socket_timeout):
316+
response = str_if_bytes(await self._reader.read(1024))
317+
else:
318+
response = str_if_bytes(await self._reader.read(1024))
319+
320+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
321+
raise ResponseError(f"Invalid PING response: {response}")
322+
except (
323+
socket.timeout,
324+
asyncio.TimeoutError,
325+
ResponseError,
326+
ConnectionResetError,
327+
) as e:
328+
# `socket_keepalive_options` might contain invalid options
329+
# causing an error. Do not leave the connection open.
330+
self._close()
331+
raise ConnectionError(self._error_message(e))
332+
298333
async def connect_check_health(self, check_health: bool = True):
299334
if self.is_connected:
300335
return
301336
try:
302337
await self.retry.call_with_retry(
303-
lambda: self._connect(), lambda error: self.disconnect()
338+
lambda: self._connect_check_ready(), lambda error: self.disconnect()
304339
)
305340
except asyncio.CancelledError:
341+
self._close()
306342
raise # in 3.7 and earlier, this is an Exception, not BaseException
307343
except (socket.timeout, asyncio.TimeoutError):
308344
raise TimeoutError("Timeout connecting to server")
@@ -526,8 +562,7 @@ async def send_packed_command(
526562
self._send_packed_command(command), self.socket_timeout
527563
)
528564
else:
529-
self._writer.writelines(command)
530-
await self._writer.drain()
565+
await self._send_packed_command(command)
531566
except asyncio.TimeoutError:
532567
await self.disconnect(nowait=True)
533568
raise TimeoutError("Timeout writing to socket") from None
@@ -774,7 +809,7 @@ async def _connect(self):
774809
except (OSError, TypeError):
775810
# `socket_keepalive_options` might contain invalid options
776811
# causing an error. Do not leave the connection open.
777-
writer.close()
812+
self._close()
778813
raise
779814

780815
def _host_error(self) -> str:
@@ -933,7 +968,6 @@ async def _connect(self):
933968
reader, writer = await asyncio.open_unix_connection(path=self.path)
934969
self._reader = reader
935970
self._writer = writer
936-
await self.on_connect()
937971

938972
def _host_error(self) -> str:
939973
return self.path

redis/client.py

+2
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def __init__(
206206
charset: Optional[str] = None,
207207
errors: Optional[str] = None,
208208
decode_responses: bool = False,
209+
check_ready: bool = False,
209210
retry_on_timeout: bool = False,
210211
retry_on_error: Optional[List[Type[Exception]]] = None,
211212
ssl: bool = False,
@@ -282,6 +283,7 @@ def __init__(
282283
"encoding": encoding,
283284
"encoding_errors": encoding_errors,
284285
"decode_responses": decode_responses,
286+
"check_ready": check_ready,
285287
"retry_on_error": retry_on_error,
286288
"retry": copy.deepcopy(retry),
287289
"max_connections": max_connections,

redis/connection.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
encoding: str = "utf-8",
237237
encoding_errors: str = "strict",
238238
decode_responses: bool = False,
239+
check_ready: bool = False,
239240
parser_class=DefaultParser,
240241
socket_read_size: int = 65536,
241242
health_check_interval: int = 0,
@@ -302,6 +303,7 @@ def __init__(
302303
self.redis_connect_func = redis_connect_func
303304
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
304305
self.handshake_metadata = None
306+
self.check_ready = check_ready
305307
self._sock = None
306308
self._socket_read_size = socket_read_size
307309
self.set_parser(parser_class)
@@ -378,12 +380,35 @@ def connect(self):
378380
"Connects to the Redis server if not already connected"
379381
self.connect_check_health(check_health=True)
380382

383+
def _connect_check_ready(self):
384+
sock = self._connect()
385+
386+
# Doing handshake since connect and send operations work even when Redis is not ready
387+
if self.check_ready:
388+
try:
389+
ping_parts = self._command_packer.pack("PING")
390+
for part in ping_parts:
391+
sock.sendall(part)
392+
393+
response = str_if_bytes(sock.recv(1024))
394+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
395+
raise ResponseError(f"Invalid PING response: {response}")
396+
except (ConnectionResetError, ResponseError) as err:
397+
try:
398+
sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
399+
except OSError:
400+
pass
401+
sock.close()
402+
raise ConnectionError(self._error_message(err))
403+
return sock
404+
381405
def connect_check_health(self, check_health: bool = True):
382406
if self._sock:
383407
return
384408
try:
385409
sock = self.retry.call_with_retry(
386-
lambda: self._connect(), lambda error: self.disconnect(error)
410+
lambda: self._connect_check_ready(),
411+
lambda error: self.disconnect(error),
387412
)
388413
except socket.timeout:
389414
raise TimeoutError("Timeout connecting to server")

tests/test_asyncio/test_cluster.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ async def test_reading_with_load_balancing_strategies(
716716
Connection,
717717
send_command=mock.DEFAULT,
718718
read_response=mock.DEFAULT,
719-
_connect=mock.DEFAULT,
719+
_connect_check_ready=mock.DEFAULT,
720720
can_read_destructive=mock.DEFAULT,
721721
on_connect=mock.DEFAULT,
722722
) as mocks:
@@ -748,7 +748,7 @@ def execute_command_mock_third(self, *args, **options):
748748
execute_command.side_effect = execute_command_mock_first
749749
mocks["send_command"].return_value = True
750750
mocks["read_response"].return_value = "OK"
751-
mocks["_connect"].return_value = True
751+
mocks["_connect_check_ready"].return_value = True
752752
mocks["can_read_destructive"].return_value = False
753753
mocks["on_connect"].return_value = True
754754

@@ -3090,13 +3090,17 @@ async def execute_command(self, *args, **kwargs):
30903090

30913091
return _create_client
30923092

3093+
@pytest.mark.parametrize("check_ready", [True, False])
30933094
async def test_ssl_connection_without_ssl(
3094-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3095+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_ready
30953096
) -> None:
30963097
with pytest.raises(RedisClusterException) as e:
3097-
await create_client(mocked=False, ssl=False)
3098+
await create_client(mocked=False, ssl=False, check_ready=check_ready)
30983099
e = e.value.__cause__
3099-
assert "Connection closed by server" in str(e)
3100+
if check_ready:
3101+
assert "Invalid PING response" in str(e)
3102+
else:
3103+
assert "Connection closed by server" in str(e)
31003104

31013105
async def test_ssl_with_invalid_cert(
31023106
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)