Skip to content

Commit 40e5fc1

Browse files
authored
StreamingCredentialProvider support (#3445)
* Added StreamingCredentialProvider interface * StreamingCredentialProvider support * Removed debug statement * Changed an approach to handle multiple connection pools * Added support for RedisCluster * Added dispatching of custom connection pool * Extended CredentialProvider interface with async API * Changed method implementation * Added support for async API * Removed unused lock * Added async API * Added support for single connection client * Added core functionality * Revert debug call * Added package to setup.py * Added handling of in-use connections * Added testing * Changed fixture name * Added marker * Marked tests with correct annotations * Added better cancelation handling * Removed another annotation * Added support for async cluster * Added pipeline tests * Added support for Pub/Sub * Added support for Pub/Sub in cluster * Added an option to parse endpoint from endpoints.json * Updated package names and ENV variables * Moved SSL certificates code into context of class * Fixed fixtures for async * Fixed test * Added better endpoitns handling * Changed variable names * Added logging * Fixed broken tests * Added TODO for SSL tests * Added error propagation to main thread * Added single connection lock * Codestyle fixes * Added missing methods * Removed wrong annotation * Fixed tests * Codestyle fix * Updated EventListener instantiation inside of class * Fixed variable name * Fixed variable names * Fixed variable name * Added EventException * Codestyle fix * Removed redundant code * Codestyle fix * Updated test case * Fixed tests * Fixed test * Removed dependency
1 parent 8f2276e commit 40e5fc1

28 files changed

+3117
-33
lines changed

.github/actions/run-tests/action.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ runs:
103103

104104
if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then
105105
echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7"
106-
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod"
106+
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration"
107107
else
108108
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}"
109109
fi

dev_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ uvloop
1616
vulture>=2.3.0
1717
wheel>=0.30.0
1818
numpy>=1.24.0
19+
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main

pytest.ini

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ markers =
1010
asyncio: marker for async tests
1111
replica: replica tests
1212
experimental: run only experimental tests
13+
cp_integration: credential provider integration tests
1314
asyncio_mode = auto
1415
timeout = 30
1516
filterwarnings =

redis/asyncio/client.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
list_or_args,
5454
)
5555
from redis.credentials import CredentialProvider
56+
from redis.event import (
57+
AfterPooledConnectionsInstantiationEvent,
58+
AfterPubSubConnectionInstantiationEvent,
59+
AfterSingleConnectionInstantiationEvent,
60+
ClientType,
61+
EventDispatcher,
62+
)
5663
from redis.exceptions import (
5764
ConnectionError,
5865
ExecAbortError,
@@ -233,6 +240,7 @@ def __init__(
233240
redis_connect_func=None,
234241
credential_provider: Optional[CredentialProvider] = None,
235242
protocol: Optional[int] = 2,
243+
event_dispatcher: Optional[EventDispatcher] = None,
236244
):
237245
"""
238246
Initialize a new Redis client.
@@ -242,6 +250,10 @@ def __init__(
242250
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
243251
"""
244252
kwargs: Dict[str, Any]
253+
if event_dispatcher is None:
254+
self._event_dispatcher = EventDispatcher()
255+
else:
256+
self._event_dispatcher = event_dispatcher
245257
# auto_close_connection_pool only has an effect if connection_pool is
246258
# None. It is assumed that if connection_pool is not None, the user
247259
# wants to manage the connection pool themselves.
@@ -320,9 +332,19 @@ def __init__(
320332
# This arg only used if no pool is passed in
321333
self.auto_close_connection_pool = auto_close_connection_pool
322334
connection_pool = ConnectionPool(**kwargs)
335+
self._event_dispatcher.dispatch(
336+
AfterPooledConnectionsInstantiationEvent(
337+
[connection_pool], ClientType.ASYNC, credential_provider
338+
)
339+
)
323340
else:
324341
# If a pool is passed in, do not close it
325342
self.auto_close_connection_pool = False
343+
self._event_dispatcher.dispatch(
344+
AfterPooledConnectionsInstantiationEvent(
345+
[connection_pool], ClientType.ASYNC, credential_provider
346+
)
347+
)
326348

327349
self.connection_pool = connection_pool
328350
self.single_connection_client = single_connection_client
@@ -354,6 +376,12 @@ async def initialize(self: _RedisT) -> _RedisT:
354376
async with self._single_conn_lock:
355377
if self.connection is None:
356378
self.connection = await self.connection_pool.get_connection("_")
379+
380+
self._event_dispatcher.dispatch(
381+
AfterSingleConnectionInstantiationEvent(
382+
self.connection, ClientType.ASYNC, self._single_conn_lock
383+
)
384+
)
357385
return self
358386

359387
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -521,7 +549,9 @@ def pubsub(self, **kwargs) -> "PubSub":
521549
subscribe to channels and listen for messages that get published to
522550
them.
523551
"""
524-
return PubSub(self.connection_pool, **kwargs)
552+
return PubSub(
553+
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
554+
)
525555

526556
def monitor(self) -> "Monitor":
527557
return Monitor(self.connection_pool)
@@ -759,7 +789,12 @@ def __init__(
759789
ignore_subscribe_messages: bool = False,
760790
encoder=None,
761791
push_handler_func: Optional[Callable] = None,
792+
event_dispatcher: Optional["EventDispatcher"] = None,
762793
):
794+
if event_dispatcher is None:
795+
self._event_dispatcher = EventDispatcher()
796+
else:
797+
self._event_dispatcher = event_dispatcher
763798
self.connection_pool = connection_pool
764799
self.shard_hint = shard_hint
765800
self.ignore_subscribe_messages = ignore_subscribe_messages
@@ -876,6 +911,12 @@ async def connect(self):
876911
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
877912
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
878913

914+
self._event_dispatcher.dispatch(
915+
AfterPubSubConnectionInstantiationEvent(
916+
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
917+
)
918+
)
919+
879920
async def _disconnect_raise_connect(self, conn, error):
880921
"""
881922
Close the connection and raise an exception

redis/asyncio/cluster.py

+54
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url
3030
from redis.asyncio.lock import Lock
3131
from redis.asyncio.retry import Retry
32+
from redis.auth.token import TokenInterface
3233
from redis.backoff import default_backoff
3334
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
3435
from redis.cluster import (
@@ -45,6 +46,7 @@
4546
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
4647
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
4748
from redis.credentials import CredentialProvider
49+
from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
4850
from redis.exceptions import (
4951
AskError,
5052
BusyLoadingError,
@@ -57,6 +59,7 @@
5759
MaxConnectionsError,
5860
MovedError,
5961
RedisClusterException,
62+
RedisError,
6063
ResponseError,
6164
SlotNotCoveredError,
6265
TimeoutError,
@@ -270,6 +273,7 @@ def __init__(
270273
ssl_ciphers: Optional[str] = None,
271274
protocol: Optional[int] = 2,
272275
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
276+
event_dispatcher: Optional[EventDispatcher] = None,
273277
) -> None:
274278
if db:
275279
raise RedisClusterException(
@@ -366,11 +370,17 @@ def __init__(
366370
if host and port:
367371
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
368372

373+
if event_dispatcher is None:
374+
self._event_dispatcher = EventDispatcher()
375+
else:
376+
self._event_dispatcher = event_dispatcher
377+
369378
self.nodes_manager = NodesManager(
370379
startup_nodes,
371380
require_full_coverage,
372381
kwargs,
373382
address_remap=address_remap,
383+
event_dispatcher=self._event_dispatcher,
374384
)
375385
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
376386
self.read_from_replicas = read_from_replicas
@@ -929,6 +939,8 @@ class ClusterNode:
929939
__slots__ = (
930940
"_connections",
931941
"_free",
942+
"_lock",
943+
"_event_dispatcher",
932944
"connection_class",
933945
"connection_kwargs",
934946
"host",
@@ -966,6 +978,9 @@ def __init__(
966978

967979
self._connections: List[Connection] = []
968980
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
981+
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
982+
if self._event_dispatcher is None:
983+
self._event_dispatcher = EventDispatcher()
969984

970985
def __repr__(self) -> str:
971986
return (
@@ -1082,10 +1097,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10821097

10831098
return ret
10841099

1100+
async def re_auth_callback(self, token: TokenInterface):
1101+
tmp_queue = collections.deque()
1102+
while self._free:
1103+
conn = self._free.popleft()
1104+
await conn.retry.call_with_retry(
1105+
lambda: conn.send_command(
1106+
"AUTH", token.try_get("oid"), token.get_value()
1107+
),
1108+
lambda error: self._mock(error),
1109+
)
1110+
await conn.retry.call_with_retry(
1111+
lambda: conn.read_response(), lambda error: self._mock(error)
1112+
)
1113+
tmp_queue.append(conn)
1114+
1115+
while tmp_queue:
1116+
conn = tmp_queue.popleft()
1117+
self._free.append(conn)
1118+
1119+
async def _mock(self, error: RedisError):
1120+
"""
1121+
Dummy functions, needs to be passed as error callback to retry object.
1122+
:param error:
1123+
:return:
1124+
"""
1125+
pass
1126+
10851127

10861128
class NodesManager:
10871129
__slots__ = (
10881130
"_moved_exception",
1131+
"_event_dispatcher",
10891132
"connection_kwargs",
10901133
"default_node",
10911134
"nodes_cache",
@@ -1102,6 +1145,7 @@ def __init__(
11021145
require_full_coverage: bool,
11031146
connection_kwargs: Dict[str, Any],
11041147
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
1148+
event_dispatcher: Optional[EventDispatcher] = None,
11051149
) -> None:
11061150
self.startup_nodes = {node.name: node for node in startup_nodes}
11071151
self.require_full_coverage = require_full_coverage
@@ -1113,6 +1157,10 @@ def __init__(
11131157
self.slots_cache: Dict[int, List["ClusterNode"]] = {}
11141158
self.read_load_balancer = LoadBalancer()
11151159
self._moved_exception: MovedError = None
1160+
if event_dispatcher is None:
1161+
self._event_dispatcher = EventDispatcher()
1162+
else:
1163+
self._event_dispatcher = event_dispatcher
11161164

11171165
def get_node(
11181166
self,
@@ -1230,6 +1278,12 @@ async def initialize(self) -> None:
12301278
try:
12311279
# Make sure cluster mode is enabled on this node
12321280
try:
1281+
self._event_dispatcher.dispatch(
1282+
AfterAsyncClusterInstantiationEvent(
1283+
self.nodes_cache,
1284+
self.connection_kwargs.get("credential_provider", None),
1285+
)
1286+
)
12331287
cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
12341288
except ResponseError:
12351289
raise RedisClusterException(

0 commit comments

Comments
 (0)