Skip to content

Commit a418075

Browse files
committed
[WIP] Add automatic reconnection to the new asyncio implementation.
Missing tests for now. Fix #1480.
1 parent c3ff254 commit a418075

File tree

5 files changed

+259
-23
lines changed

5 files changed

+259
-23
lines changed

docs/howto/upgrade.rst

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,6 @@ Following redirects
7979
The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
8080
redirects yet.
8181

82-
Automatic reconnection
83-
......................
84-
85-
The new implementation of :func:`~asyncio.client.connect` doesn't provide
86-
automatic reconnection yet.
87-
88-
In other words, the following pattern isn't supported::
89-
90-
from websockets.asyncio.client import connect
91-
92-
async for websocket in connect(...): # this doesn't work yet
93-
...
94-
9582
.. _Update import paths:
9683

9784
Import paths
@@ -185,6 +172,29 @@ it simpler.
185172
``process_response`` replaces ``extra_headers`` and provides more flexibility.
186173
See process_request_, select_subprotocol_, and process_response_ below.
187174

175+
Customizing automatic reconnection
176+
..................................
177+
178+
On the client side, if you're reconnecting automatically with ``async for ... in
179+
connect(...)``, the behavior when a connection attempt fails was enhanced and
180+
made configurable.
181+
182+
The original implementation retried on any error. The new implementation uses an
183+
heuristic to determine whether an error is retriable or fatal. By default, only
184+
network errors and servers errors are considered retriable. You can customize
185+
this behavior with the ``process_exception`` argument of
186+
:func:`~asyncio.client.connect`.
187+
188+
See :func:`~asyncio.client.process_exception` for more information.
189+
190+
Here's how to revert to the behavior of the original implementation::
191+
192+
def process_exception(exc):
193+
return exc
194+
195+
async for ... in connect(..., process_exception=process_exception):
196+
...
197+
188198
Tracking open connections
189199
.........................
190200

docs/project/changelog.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@ Backwards-incompatible changes
4646
New features
4747
............
4848

49-
* Made the set of active connections available in the :attr:`Server.connections
50-
<asyncio.server.Server.connections>` property.
49+
* Added support for reconnecting automatically by using
50+
:func:`~asyncio.client.connect` as an asynchronous iterator to the new
51+
:mod:`asyncio` implementation.
5152

5253
* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
5354
implementations of servers.
5455

56+
* Made the set of active connections available in the :attr:`Server.connections
57+
<asyncio.server.Server.connections>` property.
58+
5559
.. _13.0:
5660

5761
13.0

docs/reference/asyncio/client.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Opening a connection
1212
.. autofunction:: unix_connect
1313
:async:
1414

15+
.. autofunction:: process_exception
16+
1517
Using a connection
1618
------------------
1719

src/websockets/asyncio/client.py

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import functools
5+
import logging
46
from types import TracebackType
5-
from typing import Any, Generator, Sequence
7+
from typing import Any, AsyncIterator, Callable, Generator, Sequence
68

7-
from ..client import ClientProtocol
9+
from ..client import ClientProtocol, backoff
810
from ..datastructures import HeadersLike
11+
from ..exceptions import InvalidStatus
912
from ..extensions.base import ClientExtensionFactory
1013
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1114
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121124
self.response_rcvd.set_result(None)
122125

123126

127+
def process_exception(exc: Exception) -> Exception | None:
128+
"""
129+
Determine whether an error is retriable or fatal.
130+
131+
When reconnecting automatically with ``async for ... in connect(...)``, if a
132+
connection attempt fails, :func:`process_exception` is called to determine
133+
whether to retry connecting or to raise the exception.
134+
135+
This function defines the default behavior, which is to retry on:
136+
137+
* :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138+
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139+
502, 503, or 504: server or proxy errors.
140+
141+
All other exceptions are considered fatal.
142+
143+
You can change this behavior with the ``process_exception`` argument of
144+
:func:`connect`.
145+
146+
Return :obj:`None` if the exception is retriable i.e. when the error could
147+
be transient and trying to reconnect with the same parameters could succeed.
148+
The exception will be logged at the ``INFO`` level.
149+
150+
Return an exception, either ``exc`` or a new exception, if the exception is
151+
fatal i.e. when trying to reconnect will most likely produce the same error.
152+
That exception will be raised, breaking out of the retry loop.
153+
154+
"""
155+
if isinstance(exc, (OSError, asyncio.TimeoutError)):
156+
return None
157+
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
158+
500, # Internal Server Error
159+
502, # Bad Gateway
160+
503, # Service Unavailable
161+
504, # Gateway Timeout
162+
]:
163+
return None
164+
return exc
165+
166+
124167
# This is spelled in lower case because it's exposed as a callable in the API.
125168
class connect:
126169
"""
@@ -138,6 +181,21 @@ class connect:
138181
139182
The connection is closed automatically when exiting the context.
140183
184+
:func:`connect` can be used as an infinite asynchronous iterator to
185+
reconnect automatically on errors::
186+
187+
async for websocket in connect(...):
188+
try:
189+
...
190+
except websockets.ConnectionClosed:
191+
continue
192+
193+
If the connection fails with a transient error, it is retried with
194+
exponential backoff. If it fails with a fatal error, the exception is
195+
raised, breaking out of the loop.
196+
197+
The connection is closed automatically after each iteration of the loop.
198+
141199
Args:
142200
uri: URI of the WebSocket server.
143201
origin: Value of the ``Origin`` header, for servers that require it.
@@ -153,6 +211,9 @@ class connect:
153211
compression: The "permessage-deflate" extension is enabled by default.
154212
Set ``compression`` to :obj:`None` to disable it. See the
155213
:doc:`compression guide <../../topics/compression>` for details.
214+
process_exception: When reconnecting automatically, tell whether an
215+
error is transient or fatal. The default behavior is defined by
216+
:func:`process_exception`. Refer to its documentation for details.
156217
open_timeout: Timeout for opening the connection in seconds.
157218
:obj:`None` disables the timeout.
158219
ping_interval: Interval between keepalive pings in seconds.
@@ -219,6 +280,7 @@ def __init__(
219280
additional_headers: HeadersLike | None = None,
220281
user_agent_header: str | None = USER_AGENT,
221282
compression: str | None = "deflate",
283+
process_exception: Callable[[Exception], Exception | None] = process_exception,
222284
# Timeouts
223285
open_timeout: float | None = 10,
224286
ping_interval: float | None = 20,
@@ -281,19 +343,29 @@ def factory() -> ClientConnection:
281343

282344
loop = asyncio.get_running_loop()
283345
if kwargs.pop("unix", False):
284-
self.create_connection = loop.create_unix_connection(factory, **kwargs)
346+
self.create_connection = functools.partial(
347+
loop.create_unix_connection, factory, **kwargs
348+
)
285349
else:
286350
if kwargs.get("sock") is None:
287351
kwargs.setdefault("host", wsuri.host)
288352
kwargs.setdefault("port", wsuri.port)
289-
self.create_connection = loop.create_connection(factory, **kwargs)
353+
self.create_connection = functools.partial(
354+
loop.create_connection, factory, **kwargs
355+
)
290356

291357
self.handshake_args = (
292358
additional_headers,
293359
user_agent_header,
294360
)
295-
361+
self.process_exception = process_exception
296362
self.open_timeout = open_timeout
363+
self.logger: LoggerLike
364+
# TODO - I'm not happy with this
365+
if logger is None:
366+
self.logger = logging.getLogger("websockets.client")
367+
else: # pragma: no cover
368+
self.logger = logger
297369

298370
# ... = await connect(...)
299371

@@ -304,7 +376,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304376
async def __await_impl__(self) -> ClientConnection:
305377
try:
306378
async with asyncio_timeout(self.open_timeout):
307-
_transport, self.connection = await self.create_connection
379+
_transport, self.connection = await self.create_connection()
308380
try:
309381
await self.connection.handshake(*self.handshake_args)
310382
except (Exception, asyncio.CancelledError):
@@ -333,6 +405,48 @@ async def __aexit__(
333405
) -> None:
334406
await self.connection.close()
335407

408+
# async for ... in connect(...):
409+
410+
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
411+
delays: Generator[float, None, None] | None = None
412+
while True:
413+
try:
414+
async with self as protocol:
415+
yield protocol
416+
except Exception as exc:
417+
# Determine whether the exception is retriable or fatal.
418+
# The API of process_exception is "return an exception or None";
419+
# "raise an exception" is also supported because it's a frequent
420+
# mistake. It isn't documented in order to keep the API simple.
421+
try:
422+
new_exc = self.process_exception(exc)
423+
except Exception as raised_exc:
424+
new_exc = raised_exc
425+
426+
# The connection failed with a fatal error.
427+
# Raise the exception and exit the loop.
428+
if new_exc is exc:
429+
raise
430+
if new_exc is not None:
431+
raise new_exc from exc
432+
433+
# The connection failed with a retriable error.
434+
# Start or continue backoff and reconnect.
435+
if delays is None:
436+
delays = backoff()
437+
delay = next(delays)
438+
self.logger.info(
439+
"! connect failed; reconnecting in %.1f seconds",
440+
delay,
441+
exc_info=True,
442+
)
443+
await asyncio.sleep(delay)
444+
continue
445+
446+
else:
447+
# The connection succeeded. Reset backoff.
448+
delays = None
449+
336450

337451
def unix_connect(
338452
path: str | None = None,

0 commit comments

Comments
 (0)