@@ -148,6 +148,7 @@ def __init__(
148
148
encoding_errors : str = "strict" ,
149
149
decode_responses : bool = False ,
150
150
parser_class : Type [BaseParser ] = DefaultParser ,
151
+ check_ready : bool = False ,
151
152
socket_read_size : int = 65536 ,
152
153
health_check_interval : float = 0 ,
153
154
client_name : Optional [str ] = None ,
@@ -204,6 +205,7 @@ def __init__(
204
205
self .health_check_interval = health_check_interval
205
206
self .next_health_check : float = - 1
206
207
self .encoder = encoder_class (encoding , encoding_errors , decode_responses )
208
+ self .check_ready = check_ready
207
209
self .redis_connect_func = redis_connect_func
208
210
self ._reader : Optional [asyncio .StreamReader ] = None
209
211
self ._writer : Optional [asyncio .StreamWriter ] = None
@@ -295,14 +297,48 @@ async def connect(self):
295
297
"""Connects to the Redis server if not already connected"""
296
298
await self .connect_check_health (check_health = True )
297
299
300
+ async def _connect_check_ready (self ):
301
+ await self ._connect ()
302
+
303
+ # Doing handshake since connect and send operations work even when Redis is not ready
304
+ if self .check_ready :
305
+ try :
306
+ ping_cmd = self .pack_command ("PING" )
307
+ if self .socket_timeout :
308
+ await asyncio .wait_for (
309
+ self ._send_packed_command (ping_cmd ), self .socket_timeout
310
+ )
311
+ else :
312
+ await self ._send_packed_command (ping_cmd )
313
+
314
+ if self .socket_timeout is not None :
315
+ async with async_timeout (self .socket_timeout ):
316
+ response = str_if_bytes (await self ._reader .read (1024 ))
317
+ else :
318
+ response = str_if_bytes (await self ._reader .read (1024 ))
319
+
320
+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
321
+ raise ResponseError (f"Invalid PING response: { response } " )
322
+ except (
323
+ socket .timeout ,
324
+ asyncio .TimeoutError ,
325
+ ResponseError ,
326
+ ConnectionResetError ,
327
+ ) as e :
328
+ # `socket_keepalive_options` might contain invalid options
329
+ # causing an error. Do not leave the connection open.
330
+ self ._close ()
331
+ raise ConnectionError (self ._error_message (e ))
332
+
298
333
async def connect_check_health (self , check_health : bool = True ):
299
334
if self .is_connected :
300
335
return
301
336
try :
302
337
await self .retry .call_with_retry (
303
- lambda : self ._connect (), lambda error : self .disconnect ()
338
+ lambda : self ._connect_check_ready (), lambda error : self .disconnect ()
304
339
)
305
340
except asyncio .CancelledError :
341
+ self ._close ()
306
342
raise # in 3.7 and earlier, this is an Exception, not BaseException
307
343
except (socket .timeout , asyncio .TimeoutError ):
308
344
raise TimeoutError ("Timeout connecting to server" )
@@ -526,8 +562,7 @@ async def send_packed_command(
526
562
self ._send_packed_command (command ), self .socket_timeout
527
563
)
528
564
else :
529
- self ._writer .writelines (command )
530
- await self ._writer .drain ()
565
+ await self ._send_packed_command (command )
531
566
except asyncio .TimeoutError :
532
567
await self .disconnect (nowait = True )
533
568
raise TimeoutError ("Timeout writing to socket" ) from None
@@ -774,7 +809,7 @@ async def _connect(self):
774
809
except (OSError , TypeError ):
775
810
# `socket_keepalive_options` might contain invalid options
776
811
# causing an error. Do not leave the connection open.
777
- writer . close ()
812
+ self . _close ()
778
813
raise
779
814
780
815
def _host_error (self ) -> str :
@@ -933,7 +968,6 @@ async def _connect(self):
933
968
reader , writer = await asyncio .open_unix_connection (path = self .path )
934
969
self ._reader = reader
935
970
self ._writer = writer
936
- await self .on_connect ()
937
971
938
972
def _host_error (self ) -> str :
939
973
return self .path
0 commit comments