Skip to content

Commit c7fa2a9

Browse files
authored
Merge branch 'master' into add-dynamic-startup-nodes-flag-to-async-redis-cluster
2 parents dd2ee3f + 8e2f2d3 commit c7fa2a9

29 files changed

+3119
-33
lines changed

.github/actions/run-tests/action.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ runs:
103103

104104
if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then
105105
echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7"
106-
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod"
106+
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration"
107107
else
108108
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}"
109109
fi

.github/workflows/pypi-publish.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ name: Publish tag to Pypi
33
on:
44
release:
55
types: [published]
6+
workflow_dispatch:
67

78
permissions:
89
contents: read # to fetch code (actions/checkout)

dev_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ uvloop
1616
vulture>=2.3.0
1717
wheel>=0.30.0
1818
numpy>=1.24.0
19+
redis-entraid==0.1.0b1

pytest.ini

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ markers =
1010
asyncio: marker for async tests
1111
replica: replica tests
1212
experimental: run only experimental tests
13+
cp_integration: credential provider integration tests
1314
asyncio_mode = auto
1415
timeout = 30
1516
filterwarnings =

redis/asyncio/client.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
list_or_args,
5454
)
5555
from redis.credentials import CredentialProvider
56+
from redis.event import (
57+
AfterPooledConnectionsInstantiationEvent,
58+
AfterPubSubConnectionInstantiationEvent,
59+
AfterSingleConnectionInstantiationEvent,
60+
ClientType,
61+
EventDispatcher,
62+
)
5663
from redis.exceptions import (
5764
ConnectionError,
5865
ExecAbortError,
@@ -233,6 +240,7 @@ def __init__(
233240
redis_connect_func=None,
234241
credential_provider: Optional[CredentialProvider] = None,
235242
protocol: Optional[int] = 2,
243+
event_dispatcher: Optional[EventDispatcher] = None,
236244
):
237245
"""
238246
Initialize a new Redis client.
@@ -242,6 +250,10 @@ def __init__(
242250
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
243251
"""
244252
kwargs: Dict[str, Any]
253+
if event_dispatcher is None:
254+
self._event_dispatcher = EventDispatcher()
255+
else:
256+
self._event_dispatcher = event_dispatcher
245257
# auto_close_connection_pool only has an effect if connection_pool is
246258
# None. It is assumed that if connection_pool is not None, the user
247259
# wants to manage the connection pool themselves.
@@ -320,9 +332,19 @@ def __init__(
320332
# This arg only used if no pool is passed in
321333
self.auto_close_connection_pool = auto_close_connection_pool
322334
connection_pool = ConnectionPool(**kwargs)
335+
self._event_dispatcher.dispatch(
336+
AfterPooledConnectionsInstantiationEvent(
337+
[connection_pool], ClientType.ASYNC, credential_provider
338+
)
339+
)
323340
else:
324341
# If a pool is passed in, do not close it
325342
self.auto_close_connection_pool = False
343+
self._event_dispatcher.dispatch(
344+
AfterPooledConnectionsInstantiationEvent(
345+
[connection_pool], ClientType.ASYNC, credential_provider
346+
)
347+
)
326348

327349
self.connection_pool = connection_pool
328350
self.single_connection_client = single_connection_client
@@ -354,6 +376,12 @@ async def initialize(self: _RedisT) -> _RedisT:
354376
async with self._single_conn_lock:
355377
if self.connection is None:
356378
self.connection = await self.connection_pool.get_connection("_")
379+
380+
self._event_dispatcher.dispatch(
381+
AfterSingleConnectionInstantiationEvent(
382+
self.connection, ClientType.ASYNC, self._single_conn_lock
383+
)
384+
)
357385
return self
358386

359387
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -521,7 +549,9 @@ def pubsub(self, **kwargs) -> "PubSub":
521549
subscribe to channels and listen for messages that get published to
522550
them.
523551
"""
524-
return PubSub(self.connection_pool, **kwargs)
552+
return PubSub(
553+
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
554+
)
525555

526556
def monitor(self) -> "Monitor":
527557
return Monitor(self.connection_pool)
@@ -759,7 +789,12 @@ def __init__(
759789
ignore_subscribe_messages: bool = False,
760790
encoder=None,
761791
push_handler_func: Optional[Callable] = None,
792+
event_dispatcher: Optional["EventDispatcher"] = None,
762793
):
794+
if event_dispatcher is None:
795+
self._event_dispatcher = EventDispatcher()
796+
else:
797+
self._event_dispatcher = event_dispatcher
763798
self.connection_pool = connection_pool
764799
self.shard_hint = shard_hint
765800
self.ignore_subscribe_messages = ignore_subscribe_messages
@@ -876,6 +911,12 @@ async def connect(self):
876911
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
877912
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
878913

914+
self._event_dispatcher.dispatch(
915+
AfterPubSubConnectionInstantiationEvent(
916+
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
917+
)
918+
)
919+
879920
async def _disconnect_raise_connect(self, conn, error):
880921
"""
881922
Close the connection and raise an exception

redis/asyncio/cluster.py

+54
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url
3030
from redis.asyncio.lock import Lock
3131
from redis.asyncio.retry import Retry
32+
from redis.auth.token import TokenInterface
3233
from redis.backoff import default_backoff
3334
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
3435
from redis.cluster import (
@@ -45,6 +46,7 @@
4546
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
4647
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
4748
from redis.credentials import CredentialProvider
49+
from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
4850
from redis.exceptions import (
4951
AskError,
5052
BusyLoadingError,
@@ -57,6 +59,7 @@
5759
MaxConnectionsError,
5860
MovedError,
5961
RedisClusterException,
62+
RedisError,
6063
ResponseError,
6164
SlotNotCoveredError,
6265
TimeoutError,
@@ -279,6 +282,7 @@ def __init__(
279282
ssl_ciphers: Optional[str] = None,
280283
protocol: Optional[int] = 2,
281284
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
285+
event_dispatcher: Optional[EventDispatcher] = None,
282286
) -> None:
283287
if db:
284288
raise RedisClusterException(
@@ -375,12 +379,18 @@ def __init__(
375379
if host and port:
376380
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
377381

382+
if event_dispatcher is None:
383+
self._event_dispatcher = EventDispatcher()
384+
else:
385+
self._event_dispatcher = event_dispatcher
386+
378387
self.nodes_manager = NodesManager(
379388
startup_nodes,
380389
require_full_coverage,
381390
kwargs,
382391
dynamic_startup_nodes=dynamic_startup_nodes,
383392
address_remap=address_remap,
393+
event_dispatcher=self._event_dispatcher,
384394
)
385395
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
386396
self.read_from_replicas = read_from_replicas
@@ -939,6 +949,8 @@ class ClusterNode:
939949
__slots__ = (
940950
"_connections",
941951
"_free",
952+
"_lock",
953+
"_event_dispatcher",
942954
"connection_class",
943955
"connection_kwargs",
944956
"host",
@@ -976,6 +988,9 @@ def __init__(
976988

977989
self._connections: List[Connection] = []
978990
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
991+
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
992+
if self._event_dispatcher is None:
993+
self._event_dispatcher = EventDispatcher()
979994

980995
def __repr__(self) -> str:
981996
return (
@@ -1092,10 +1107,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10921107

10931108
return ret
10941109

1110+
async def re_auth_callback(self, token: TokenInterface):
1111+
tmp_queue = collections.deque()
1112+
while self._free:
1113+
conn = self._free.popleft()
1114+
await conn.retry.call_with_retry(
1115+
lambda: conn.send_command(
1116+
"AUTH", token.try_get("oid"), token.get_value()
1117+
),
1118+
lambda error: self._mock(error),
1119+
)
1120+
await conn.retry.call_with_retry(
1121+
lambda: conn.read_response(), lambda error: self._mock(error)
1122+
)
1123+
tmp_queue.append(conn)
1124+
1125+
while tmp_queue:
1126+
conn = tmp_queue.popleft()
1127+
self._free.append(conn)
1128+
1129+
async def _mock(self, error: RedisError):
1130+
"""
1131+
Dummy functions, needs to be passed as error callback to retry object.
1132+
:param error:
1133+
:return:
1134+
"""
1135+
pass
1136+
10951137

10961138
class NodesManager:
10971139
__slots__ = (
10981140
"_moved_exception",
1141+
"_event_dispatcher",
10991142
"connection_kwargs",
11001143
"default_node",
11011144
"nodes_cache",
@@ -1114,6 +1157,7 @@ def __init__(
11141157
connection_kwargs: Dict[str, Any],
11151158
dynamic_startup_nodes: bool = True,
11161159
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1160+
event_dispatcher: Optional[EventDispatcher] = None,
11171161
) -> None:
11181162
self.startup_nodes = {node.name: node for node in startup_nodes}
11191163
self.require_full_coverage = require_full_coverage
@@ -1126,6 +1170,10 @@ def __init__(
11261170
self.slots_cache: Dict[int, List["ClusterNode"]] = {}
11271171
self.read_load_balancer = LoadBalancer()
11281172
self._moved_exception: MovedError = None
1173+
if event_dispatcher is None:
1174+
self._event_dispatcher = EventDispatcher()
1175+
else:
1176+
self._event_dispatcher = event_dispatcher
11291177

11301178
def get_node(
11311179
self,
@@ -1243,6 +1291,12 @@ async def initialize(self) -> None:
12431291
try:
12441292
# Make sure cluster mode is enabled on this node
12451293
try:
1294+
self._event_dispatcher.dispatch(
1295+
AfterAsyncClusterInstantiationEvent(
1296+
self.nodes_cache,
1297+
self.connection_kwargs.get("credential_provider", None),
1298+
)
1299+
)
12461300
cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
12471301
except ResponseError:
12481302
raise RedisClusterException(

0 commit comments

Comments
 (0)