Skip to content

Commit d8ab09b

Browse files
committed
Add automatic reconnection to the new asyncio implementation.
Missing tests for now. Fix #1480.
1 parent 6ffb6b0 commit d8ab09b

File tree

5 files changed

+263
-23
lines changed

5 files changed

+263
-23
lines changed

docs/howto/upgrade.rst

+20-13
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,6 @@ Following redirects
7979
The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
8080
redirects yet.
8181

82-
Automatic reconnection
83-
......................
84-
85-
The new implementation of :func:`~asyncio.client.connect` doesn't provide
86-
automatic reconnection yet.
87-
88-
In other words, the following pattern isn't supported::
89-
90-
from websockets.asyncio.client import connect
91-
92-
async for websocket in connect(...): # this doesn't work yet
93-
...
94-
9582
.. _Update import paths:
9683

9784
Import paths
@@ -185,6 +172,26 @@ it simpler.
185172
``process_response`` replaces ``extra_headers`` and provides more flexibility.
186173
See process_request_, select_subprotocol_, and process_response_ below.
187174

175+
Customizing automatic reconnection
176+
..................................
177+
178+
On the client side, if you're reconnecting automatically with ``async for ... in
179+
connect(...)``, the behavior when a connection attempt fails was enhanced and
180+
made configurable.
181+
182+
The original implementation retried on any error. The new implementation uses an
183+
heuristic to determine whether an error is retryable or fatal. By default, only
184+
network errors and server errors (HTTP 500, 502, 503, or 504) are considered
185+
retryable. You can customize this behavior with the ``process_exception``
186+
argument of :func:`~asyncio.client.connect`.
187+
188+
See :func:`~asyncio.client.process_exception` for more information.
189+
190+
Here's how to revert to the behavior of the original implementation::
191+
192+
async for ... in connect(..., process_exception=lambda exc: exc):
193+
...
194+
188195
Tracking open connections
189196
.........................
190197

docs/project/changelog.rst

+6-2
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@ Backwards-incompatible changes
4646
New features
4747
............
4848

49-
* Made the set of active connections available in the :attr:`Server.connections
50-
<asyncio.server.Server.connections>` property.
49+
* Added support for reconnecting automatically by using
50+
:func:`~asyncio.client.connect` as an asynchronous iterator to the new
51+
:mod:`asyncio` implementation.
5152

5253
* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
5354
implementations of servers.
5455

56+
* Made the set of active connections available in the :attr:`Server.connections
57+
<asyncio.server.Server.connections>` property.
58+
5559
.. _13.0:
5660

5761
13.0

docs/reference/asyncio/client.rst

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Opening a connection
1212
.. autofunction:: unix_connect
1313
:async:
1414

15+
.. autofunction:: process_exception
16+
1517
Using a connection
1618
------------------
1719

src/websockets/asyncio/client.py

+117-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import functools
5+
import logging
46
from types import TracebackType
5-
from typing import Any, Generator, Sequence
7+
from typing import Any, AsyncIterator, Callable, Generator, Sequence
68

7-
from ..client import ClientProtocol
9+
from ..client import ClientProtocol, backoff
810
from ..datastructures import HeadersLike
11+
from ..exceptions import InvalidStatus
912
from ..extensions.base import ClientExtensionFactory
1013
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1114
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121124
self.response_rcvd.set_result(None)
122125

123126

127+
def process_exception(exc: Exception) -> Exception | None:
128+
"""
129+
Determine whether an error is retryable or fatal.
130+
131+
When reconnecting automatically with ``async for ... in connect(...)``, if a
132+
connection attempt fails, :func:`process_exception` is called to determine
133+
whether to retry connecting or to raise the exception.
134+
135+
This function defines the default behavior, which is to retry on:
136+
137+
* :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138+
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139+
502, 503, or 504: server or proxy errors.
140+
141+
All other exceptions are considered fatal.
142+
143+
You can change this behavior with the ``process_exception`` argument of
144+
:func:`connect`.
145+
146+
Return :obj:`None` if the exception is retryable i.e. when the error could
147+
be transient and trying to reconnect with the same parameters could succeed.
148+
The exception will be logged at the ``INFO`` level.
149+
150+
Return an exception, either ``exc`` or a new exception, if the exception is
151+
fatal i.e. when trying to reconnect will most likely produce the same error.
152+
That exception will be raised, breaking out of the retry loop.
153+
154+
"""
155+
if isinstance(exc, (OSError, asyncio.TimeoutError)):
156+
return None
157+
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
158+
500, # Internal Server Error
159+
502, # Bad Gateway
160+
503, # Service Unavailable
161+
504, # Gateway Timeout
162+
]:
163+
return None
164+
return exc
165+
166+
124167
# This is spelled in lower case because it's exposed as a callable in the API.
125168
class connect:
126169
"""
@@ -138,6 +181,21 @@ class connect:
138181
139182
The connection is closed automatically when exiting the context.
140183
184+
:func:`connect` can be used as an infinite asynchronous iterator to
185+
reconnect automatically on errors::
186+
187+
async for websocket in connect(...):
188+
try:
189+
...
190+
except websockets.ConnectionClosed:
191+
continue
192+
193+
If the connection fails with a transient error, it is retried with
194+
exponential backoff. If it fails with a fatal error, the exception is
195+
raised, breaking out of the loop.
196+
197+
The connection is closed automatically after each iteration of the loop.
198+
141199
Args:
142200
uri: URI of the WebSocket server.
143201
origin: Value of the ``Origin`` header, for servers that require it.
@@ -153,6 +211,9 @@ class connect:
153211
compression: The "permessage-deflate" extension is enabled by default.
154212
Set ``compression`` to :obj:`None` to disable it. See the
155213
:doc:`compression guide <../../topics/compression>` for details.
214+
process_exception: When reconnecting automatically, tell whether an
215+
error is transient or fatal. The default behavior is defined by
216+
:func:`process_exception`. Refer to its documentation for details.
156217
open_timeout: Timeout for opening the connection in seconds.
157218
:obj:`None` disables the timeout.
158219
ping_interval: Interval between keepalive pings in seconds.
@@ -219,6 +280,7 @@ def __init__(
219280
additional_headers: HeadersLike | None = None,
220281
user_agent_header: str | None = USER_AGENT,
221282
compression: str | None = "deflate",
283+
process_exception: Callable[[Exception], Exception | None] = process_exception,
222284
# Timeouts
223285
open_timeout: float | None = 10,
224286
ping_interval: float | None = 20,
@@ -281,19 +343,26 @@ def factory() -> ClientConnection:
281343

282344
loop = asyncio.get_running_loop()
283345
if kwargs.pop("unix", False):
284-
self.create_connection = loop.create_unix_connection(factory, **kwargs)
346+
self.create_connection = functools.partial(
347+
loop.create_unix_connection, factory, **kwargs
348+
)
285349
else:
286350
if kwargs.get("sock") is None:
287351
kwargs.setdefault("host", wsuri.host)
288352
kwargs.setdefault("port", wsuri.port)
289-
self.create_connection = loop.create_connection(factory, **kwargs)
353+
self.create_connection = functools.partial(
354+
loop.create_connection, factory, **kwargs
355+
)
290356

291357
self.handshake_args = (
292358
additional_headers,
293359
user_agent_header,
294360
)
295-
361+
self.process_exception = process_exception
296362
self.open_timeout = open_timeout
363+
if logger is None:
364+
logger = logging.getLogger("websockets.client")
365+
self.logger = logger
297366

298367
# ... = await connect(...)
299368

@@ -304,7 +373,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304373
async def __await_impl__(self) -> ClientConnection:
305374
try:
306375
async with asyncio_timeout(self.open_timeout):
307-
_transport, self.connection = await self.create_connection
376+
_transport, self.connection = await self.create_connection()
308377
try:
309378
await self.connection.handshake(*self.handshake_args)
310379
except (Exception, asyncio.CancelledError):
@@ -333,6 +402,48 @@ async def __aexit__(
333402
) -> None:
334403
await self.connection.close()
335404

405+
# async for ... in connect(...):
406+
407+
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
408+
delays: Generator[float, None, None] | None = None
409+
while True:
410+
try:
411+
async with self as protocol:
412+
yield protocol
413+
except Exception as exc:
414+
# Determine whether the exception is retryable or fatal.
415+
# The API of process_exception is "return an exception or None";
416+
# "raise an exception" is also supported because it's a frequent
417+
# mistake. It isn't documented in order to keep the API simple.
418+
try:
419+
new_exc = self.process_exception(exc)
420+
except Exception as raised_exc:
421+
new_exc = raised_exc
422+
423+
# The connection failed with a fatal error.
424+
# Raise the exception and exit the loop.
425+
if new_exc is exc:
426+
raise
427+
if new_exc is not None:
428+
raise new_exc from exc
429+
430+
# The connection failed with a retryable error.
431+
# Start or continue backoff and reconnect.
432+
if delays is None:
433+
delays = backoff()
434+
delay = next(delays)
435+
self.logger.info(
436+
"! connect failed; reconnecting in %.1f seconds",
437+
delay,
438+
exc_info=True,
439+
)
440+
await asyncio.sleep(delay)
441+
continue
442+
443+
else:
444+
# The connection succeeded. Reset backoff.
445+
delays = None
446+
336447

337448
def unix_connect(
338449
path: str | None = None,

0 commit comments

Comments
 (0)