29
29
from redis .asyncio .connection import Connection , DefaultParser , SSLConnection , parse_url
30
30
from redis .asyncio .lock import Lock
31
31
from redis .asyncio .retry import Retry
32
+ from redis .auth .token import TokenInterface
32
33
from redis .backoff import default_backoff
33
34
from redis .client import EMPTY_RESPONSE , NEVER_DECODE , AbstractRedis
34
35
from redis .cluster import (
45
46
from redis .commands import READ_COMMANDS , AsyncRedisClusterCommands
46
47
from redis .crc import REDIS_CLUSTER_HASH_SLOTS , key_slot
47
48
from redis .credentials import CredentialProvider
49
+ from redis .event import AfterAsyncClusterInstantiationEvent , EventDispatcher
48
50
from redis .exceptions import (
49
51
AskError ,
50
52
BusyLoadingError ,
57
59
MaxConnectionsError ,
58
60
MovedError ,
59
61
RedisClusterException ,
62
+ RedisError ,
60
63
ResponseError ,
61
64
SlotNotCoveredError ,
62
65
TimeoutError ,
@@ -279,6 +282,7 @@ def __init__(
279
282
ssl_ciphers : Optional [str ] = None ,
280
283
protocol : Optional [int ] = 2 ,
281
284
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
285
+ event_dispatcher : Optional [EventDispatcher ] = None ,
282
286
) -> None :
283
287
if db :
284
288
raise RedisClusterException (
@@ -375,12 +379,18 @@ def __init__(
375
379
if host and port :
376
380
startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
377
381
382
+ if event_dispatcher is None :
383
+ self ._event_dispatcher = EventDispatcher ()
384
+ else :
385
+ self ._event_dispatcher = event_dispatcher
386
+
378
387
self .nodes_manager = NodesManager (
379
388
startup_nodes ,
380
389
require_full_coverage ,
381
390
kwargs ,
382
391
dynamic_startup_nodes = dynamic_startup_nodes ,
383
392
address_remap = address_remap ,
393
+ event_dispatcher = self ._event_dispatcher ,
384
394
)
385
395
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
386
396
self .read_from_replicas = read_from_replicas
@@ -939,6 +949,8 @@ class ClusterNode:
939
949
__slots__ = (
940
950
"_connections" ,
941
951
"_free" ,
952
+ "_lock" ,
953
+ "_event_dispatcher" ,
942
954
"connection_class" ,
943
955
"connection_kwargs" ,
944
956
"host" ,
@@ -976,6 +988,9 @@ def __init__(
976
988
977
989
self ._connections : List [Connection ] = []
978
990
self ._free : Deque [Connection ] = collections .deque (maxlen = self .max_connections )
991
+ self ._event_dispatcher = self .connection_kwargs .get ("event_dispatcher" , None )
992
+ if self ._event_dispatcher is None :
993
+ self ._event_dispatcher = EventDispatcher ()
979
994
980
995
def __repr__ (self ) -> str :
981
996
return (
@@ -1092,10 +1107,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
1092
1107
1093
1108
return ret
1094
1109
1110
+ async def re_auth_callback (self , token : TokenInterface ):
1111
+ tmp_queue = collections .deque ()
1112
+ while self ._free :
1113
+ conn = self ._free .popleft ()
1114
+ await conn .retry .call_with_retry (
1115
+ lambda : conn .send_command (
1116
+ "AUTH" , token .try_get ("oid" ), token .get_value ()
1117
+ ),
1118
+ lambda error : self ._mock (error ),
1119
+ )
1120
+ await conn .retry .call_with_retry (
1121
+ lambda : conn .read_response (), lambda error : self ._mock (error )
1122
+ )
1123
+ tmp_queue .append (conn )
1124
+
1125
+ while tmp_queue :
1126
+ conn = tmp_queue .popleft ()
1127
+ self ._free .append (conn )
1128
+
1129
+ async def _mock (self , error : RedisError ):
1130
+ """
1131
+ Dummy functions, needs to be passed as error callback to retry object.
1132
+ :param error:
1133
+ :return:
1134
+ """
1135
+ pass
1136
+
1095
1137
1096
1138
class NodesManager :
1097
1139
__slots__ = (
1098
1140
"_moved_exception" ,
1141
+ "_event_dispatcher" ,
1099
1142
"connection_kwargs" ,
1100
1143
"default_node" ,
1101
1144
"nodes_cache" ,
@@ -1114,6 +1157,7 @@ def __init__(
1114
1157
connection_kwargs : Dict [str , Any ],
1115
1158
dynamic_startup_nodes : bool = True ,
1116
1159
address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
1160
+ event_dispatcher : Optional [EventDispatcher ] = None ,
1117
1161
) -> None :
1118
1162
self .startup_nodes = {node .name : node for node in startup_nodes }
1119
1163
self .require_full_coverage = require_full_coverage
@@ -1126,6 +1170,10 @@ def __init__(
1126
1170
self .slots_cache : Dict [int , List ["ClusterNode" ]] = {}
1127
1171
self .read_load_balancer = LoadBalancer ()
1128
1172
self ._moved_exception : MovedError = None
1173
+ if event_dispatcher is None :
1174
+ self ._event_dispatcher = EventDispatcher ()
1175
+ else :
1176
+ self ._event_dispatcher = event_dispatcher
1129
1177
1130
1178
def get_node (
1131
1179
self ,
@@ -1243,6 +1291,12 @@ async def initialize(self) -> None:
1243
1291
try :
1244
1292
# Make sure cluster mode is enabled on this node
1245
1293
try :
1294
+ self ._event_dispatcher .dispatch (
1295
+ AfterAsyncClusterInstantiationEvent (
1296
+ self .nodes_cache ,
1297
+ self .connection_kwargs .get ("credential_provider" , None ),
1298
+ )
1299
+ )
1246
1300
cluster_slots = await startup_node .execute_command ("CLUSTER SLOTS" )
1247
1301
except ResponseError :
1248
1302
raise RedisClusterException (
0 commit comments