29
29
from redis .asyncio .connection import Connection , DefaultParser , SSLConnection , parse_url
30
30
from redis .asyncio .lock import Lock
31
31
from redis .asyncio .retry import Retry
32
+ from redis .auth .token import TokenInterface
32
33
from redis .backoff import default_backoff
33
34
from redis .client import EMPTY_RESPONSE , NEVER_DECODE , AbstractRedis
34
35
from redis .cluster import (
45
46
from redis .commands import READ_COMMANDS , AsyncRedisClusterCommands
46
47
from redis .crc import REDIS_CLUSTER_HASH_SLOTS , key_slot
47
48
from redis .credentials import CredentialProvider
49
+ from redis .event import AfterAsyncClusterInstantiationEvent , EventDispatcher
48
50
from redis .exceptions import (
49
51
AskError ,
50
52
BusyLoadingError ,
57
59
MaxConnectionsError ,
58
60
MovedError ,
59
61
RedisClusterException ,
62
+ RedisError ,
60
63
ResponseError ,
61
64
SlotNotCoveredError ,
62
65
TimeoutError ,
@@ -270,6 +273,7 @@ def __init__(
270
273
ssl_ciphers : Optional [str ] = None ,
271
274
protocol : Optional [int ] = 2 ,
272
275
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
276
+ event_dispatcher : Optional [EventDispatcher ] = None ,
273
277
) -> None :
274
278
if db :
275
279
raise RedisClusterException (
@@ -366,11 +370,17 @@ def __init__(
366
370
if host and port :
367
371
startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
368
372
373
+ if event_dispatcher is None :
374
+ self ._event_dispatcher = EventDispatcher ()
375
+ else :
376
+ self ._event_dispatcher = event_dispatcher
377
+
369
378
self .nodes_manager = NodesManager (
370
379
startup_nodes ,
371
380
require_full_coverage ,
372
381
kwargs ,
373
382
address_remap = address_remap ,
383
+ event_dispatcher = self ._event_dispatcher ,
374
384
)
375
385
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
376
386
self .read_from_replicas = read_from_replicas
@@ -929,6 +939,8 @@ class ClusterNode:
929
939
__slots__ = (
930
940
"_connections" ,
931
941
"_free" ,
942
+ "_lock" ,
943
+ "_event_dispatcher" ,
932
944
"connection_class" ,
933
945
"connection_kwargs" ,
934
946
"host" ,
@@ -966,6 +978,9 @@ def __init__(
966
978
967
979
self ._connections : List [Connection ] = []
968
980
self ._free : Deque [Connection ] = collections .deque (maxlen = self .max_connections )
981
+ self ._event_dispatcher = self .connection_kwargs .get ("event_dispatcher" , None )
982
+ if self ._event_dispatcher is None :
983
+ self ._event_dispatcher = EventDispatcher ()
969
984
970
985
def __repr__ (self ) -> str :
971
986
return (
@@ -1082,10 +1097,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1082
1097
1083
1098
return ret
1084
1099
1100
+ async def re_auth_callback (self , token : TokenInterface ):
1101
+ tmp_queue = collections .deque ()
1102
+ while self ._free :
1103
+ conn = self ._free .popleft ()
1104
+ await conn .retry .call_with_retry (
1105
+ lambda : conn .send_command (
1106
+ "AUTH" , token .try_get ("oid" ), token .get_value ()
1107
+ ),
1108
+ lambda error : self ._mock (error ),
1109
+ )
1110
+ await conn .retry .call_with_retry (
1111
+ lambda : conn .read_response (), lambda error : self ._mock (error )
1112
+ )
1113
+ tmp_queue .append (conn )
1114
+
1115
+ while tmp_queue :
1116
+ conn = tmp_queue .popleft ()
1117
+ self ._free .append (conn )
1118
+
1119
+ async def _mock (self , error : RedisError ):
1120
+ """
1121
+ Dummy functions, needs to be passed as error callback to retry object.
1122
+ :param error:
1123
+ :return:
1124
+ """
1125
+ pass
1126
+
1085
1127
1086
1128
class NodesManager :
1087
1129
__slots__ = (
1088
1130
"_moved_exception" ,
1131
+ "_event_dispatcher" ,
1089
1132
"connection_kwargs" ,
1090
1133
"default_node" ,
1091
1134
"nodes_cache" ,
@@ -1102,6 +1145,7 @@ def __init__(
1102
1145
require_full_coverage : bool ,
1103
1146
connection_kwargs : Dict [str , Any ],
1104
1147
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
1148
+ event_dispatcher : Optional [EventDispatcher ] = None ,
1105
1149
) -> None :
1106
1150
self .startup_nodes = {node .name : node for node in startup_nodes }
1107
1151
self .require_full_coverage = require_full_coverage
@@ -1113,6 +1157,10 @@ def __init__(
1113
1157
self .slots_cache : Dict [int , List ["ClusterNode" ]] = {}
1114
1158
self .read_load_balancer = LoadBalancer ()
1115
1159
self ._moved_exception : MovedError = None
1160
+ if event_dispatcher is None :
1161
+ self ._event_dispatcher = EventDispatcher ()
1162
+ else :
1163
+ self ._event_dispatcher = event_dispatcher
1116
1164
1117
1165
def get_node (
1118
1166
self ,
@@ -1230,6 +1278,12 @@ async def initialize(self) -> None:
1230
1278
try :
1231
1279
# Make sure cluster mode is enabled on this node
1232
1280
try :
1281
+ self ._event_dispatcher .dispatch (
1282
+ AfterAsyncClusterInstantiationEvent (
1283
+ self .nodes_cache ,
1284
+ self .connection_kwargs .get ("credential_provider" , None ),
1285
+ )
1286
+ )
1233
1287
cluster_slots = await startup_node .execute_command ("CLUSTER SLOTS" )
1234
1288
except ResponseError :
1235
1289
raise RedisClusterException (
0 commit comments