Skip to content

Commit e7573a8

Browse files
committed
[WIP] Add automatic reconnection to the new asyncio implementation.
Missing tests for now. Fix #1480.
1 parent 51a31af commit e7573a8

File tree

4 files changed

+151
-31
lines changed

4 files changed

+151
-31
lines changed

docs/howto/upgrade.rst

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,6 @@ Following redirects
7979
The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
8080
redirects yet.
8181

82-
Automatic reconnection
83-
......................
84-
85-
The new implementation of :func:`~asyncio.client.connect` doesn't provide
86-
automatic reconnection yet.
87-
88-
In other words, the following pattern isn't supported::
89-
90-
from websockets.asyncio.client import connect
91-
92-
async for websocket in connect(...): # this doesn't work yet
93-
...
94-
9582
.. _Update import paths:
9683

9784
Import paths
@@ -185,6 +172,29 @@ it simpler.
185172
``process_response`` replaces ``extra_headers`` and provides more flexibility.
186173
See process_request_, select_subprotocol_, and process_response_ below.
187174

175+
Customizing automatic reconnection
176+
..................................
177+
178+
On the client side, if you're reconnecting automatically with ``async for ... in
179+
connect(...)``, the behavior when a connection attempt fails was enhanced and
180+
made configurable.
181+
182+
The original implementation retried on all errors. The new implementation
183+
provides an heuristic to determine whether an error is retryable or fatal. By
184+
default, only network errors and servers errors are considered retryable. You
185+
can customize this behavior with the ``process_exception`` argument of
186+
:func:`~asyncio.client.connect`.
187+
188+
See :func:`~asyncio.client.process_exception` for more information.
189+
190+
Here's how to revert to the behavior of the original implementation::
191+
192+
def process_exception(exc):
193+
return exc
194+
195+
async for ... in connect(..., process_exception=process_exception):
196+
...
197+
188198
Tracking open connections
189199
.........................
190200

docs/project/changelog.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,17 @@ Backwards-incompatible changes
4646
New features
4747
............
4848

49-
* Made the set of active connections available in the :attr:`Server.connections
50-
<asyncio.server.Server.connections>` property.
49+
* Added support for reconnecting automatically by using
50+
:func:`~legacy.asyncio.connect` as an asynchronous iterator to the new
51+
:mod:`asyncio` implementation.
5152

5253
* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
5354
implementations of servers.
5455

56+
* Made the set of active connections available in the :attr:`Server.connections
57+
<asyncio.server.Server.connections>` property.
58+
59+
5560
.. _13.0:
5661

5762
13.0

docs/reference/asyncio/client.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Opening a connection
1212
.. autofunction:: unix_connect
1313
:async:
1414

15+
.. autofunction:: process_exception
16+
1517
Using a connection
1618
------------------
1719

src/websockets/asyncio/client.py

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import functools
5+
import logging
46
from types import TracebackType
5-
from typing import Any, Generator, Sequence
7+
from typing import Any, AsyncIterator, Callable, Generator, Sequence
68

7-
from ..client import ClientProtocol
9+
from ..client import ClientProtocol, backoff
810
from ..datastructures import HeadersLike
11+
from ..exceptions import InvalidStatus
912
from ..extensions.base import ClientExtensionFactory
1013
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1114
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121124
self.response_rcvd.set_result(None)
122125

123126

127+
def process_exception(exc: Exception) -> Exception | None:
128+
"""
129+
Determine whether an error is retryable or fatal.
130+
131+
When reconnecting automatically with ``async for ... in connect(...)``, if a
132+
connection attempt fails, :func:`process_exception` is called to determine
133+
whether to retry connecting or to raise the exception.
134+
135+
This function defines the default behavior, which is to retry on:
136+
137+
* :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138+
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139+
502, 503, or 504: server or proxy errors.
140+
141+
All other exceptions are considered fatal.
142+
143+
You can change this behavior with the ``process_exception`` argument of
144+
:func:`connect`.
145+
146+
Return :obj:`None` if the exception is retryable i.e. when the error could
147+
be transient and trying to reconnect with the same parameters could succeed.
148+
The exception will be logged at the ``INFO`` level.
149+
150+
Return an exception, either ``exc`` or a new exception, if the exception is
151+
fatal i.e. when trying to reconnect will most likely produce the same error.
152+
That exception will be raised, breaking out of the retry loop.
153+
154+
"""
155+
if isinstance(exc, (OSError, asyncio.TimeoutError)):
156+
return None
157+
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
158+
500, # Internal Server Error
159+
502, # Bad Gateway
160+
503, # Service Unavailable
161+
504, # Gateway Timeout
162+
]:
163+
return None
164+
return exc
165+
166+
124167
# This is spelled in lower case because it's exposed as a callable in the API.
125168
class connect:
126169
"""
@@ -138,21 +181,39 @@ class connect:
138181
139182
The connection is closed automatically when exiting the context.
140183
184+
:func:`connect` can be used as an infinite asynchronous iterator to
185+
reconnect automatically on errors::
186+
187+
async for websocket in connect(...):
188+
try:
189+
...
190+
except websockets.ConnectionClosed:
191+
continue
192+
193+
If the connection fails with a transient error, it is retried with
194+
exponential backoff. If it fails with a fatal error, the exception is
195+
raised, breaking out of the loop.
196+
197+
The connection is closed automatically after each iteration of the loop.
198+
141199
Args:
142-
uri: URI of the WebSocket server.
143-
origin: Value of the ``Origin`` header, for servers that require it.
144-
extensions: List of supported extensions, in order in which they
200+
uri: URI of the WebSocket server. origin: Value of the ``Origin``
201+
header, for servers that require it. extensions: List of supported
202+
extensions, in order in which they
145203
should be negotiated and run.
146204
subprotocols: List of supported subprotocols, in order of decreasing
147205
preference.
148206
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
149207
to the handshake request.
150208
user_agent_header: Value of the ``User-Agent`` request header.
151-
It defaults to ``"Python/x.y.z websockets/X.Y"``.
152-
Setting it to :obj:`None` removes the header.
209+
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
210+
:obj:`None` removes the header.
153211
compression: The "permessage-deflate" extension is enabled by default.
154212
Set ``compression`` to :obj:`None` to disable it. See the
155213
:doc:`compression guide <../../topics/compression>` for details.
214+
process_exception: When reconnecting automatically, tell whether an
215+
error is transient or fatal. The default behavior is defined by
216+
:func:`process_exception`. Refer to its documentation for details.
156217
open_timeout: Timeout for opening the connection in seconds.
157218
:obj:`None` disables the timeout.
158219
ping_interval: Interval between keepalive pings in seconds.
@@ -172,8 +233,8 @@ class connect:
172233
to 32 KiB. You may pass a ``(high, low)`` tuple to set the
173234
high-water and low-water marks.
174235
logger: Logger for this client.
175-
It defaults to ``logging.getLogger("websockets.client")``.
176-
See the :doc:`logging guide <../../topics/logging>` for details.
236+
It defaults to ``logging.getLogger("websockets.client")``. See the
237+
:doc:`logging guide <../../topics/logging>` for details.
177238
create_connection: Factory for the :class:`ClientConnection` managing
178239
the connection. Set it to a wrapper or a subclass to customize
179240
connection handling.
@@ -201,9 +262,8 @@ class connect:
201262
client socket and customize it.
202263
203264
Raises:
204-
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
205-
OSError: If the TCP connection fails.
206-
InvalidHandshake: If the opening handshake fails.
265+
InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP
266+
connection fails. InvalidHandshake: If the opening handshake fails.
207267
TimeoutError: If the opening handshake times out.
208268
209269
"""
@@ -219,6 +279,7 @@ def __init__(
219279
additional_headers: HeadersLike | None = None,
220280
user_agent_header: str | None = USER_AGENT,
221281
compression: str | None = "deflate",
282+
process_exception: Callable[[Exception], Exception | None] = process_exception,
222283
# Timeouts
223284
open_timeout: float | None = 10,
224285
ping_interval: float | None = 20,
@@ -281,19 +342,28 @@ def factory() -> ClientConnection:
281342

282343
loop = asyncio.get_running_loop()
283344
if kwargs.pop("unix", False):
284-
self.create_connection = loop.create_unix_connection(factory, **kwargs)
345+
self.create_connection = functools.partial(
346+
loop.create_unix_connection, factory, **kwargs
347+
)
285348
else:
286349
if kwargs.get("sock") is None:
287350
kwargs.setdefault("host", wsuri.host)
288351
kwargs.setdefault("port", wsuri.port)
289-
self.create_connection = loop.create_connection(factory, **kwargs)
352+
self.create_connection = functools.partial(
353+
loop.create_connection, factory, **kwargs
354+
)
290355

291356
self.handshake_args = (
292357
additional_headers,
293358
user_agent_header,
294359
)
295-
360+
self.process_exception = process_exception
296361
self.open_timeout = open_timeout
362+
self.logger: LoggerLike
363+
if logger is None:
364+
self.logger = logging.getLogger("websockets.client")
365+
else:
366+
self.logger = logger
297367

298368
# ... = await connect(...)
299369

@@ -304,7 +374,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304374
async def __await_impl__(self) -> ClientConnection:
305375
try:
306376
async with asyncio_timeout(self.open_timeout):
307-
_transport, self.connection = await self.create_connection
377+
_transport, self.connection = await self.create_connection()
308378
try:
309379
await self.connection.handshake(*self.handshake_args)
310380
except (Exception, asyncio.CancelledError):
@@ -333,6 +403,39 @@ async def __aexit__(
333403
) -> None:
334404
await self.connection.close()
335405

406+
# async for ... in connect(...):
407+
408+
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
409+
delays: Generator[float, None, None] | None = None
410+
while True:
411+
try:
412+
async with self as protocol:
413+
yield protocol
414+
except Exception as exc:
415+
# Exit the loop if the error isn't retryable.
416+
new_exc = self.process_exception(exc)
417+
if new_exc is exc:
418+
raise
419+
if new_exc is not None:
420+
raise new_exc from exc
421+
422+
# The connection failed with a retryable error.
423+
# Start or continue backoff and reconnect.
424+
if delays is None:
425+
delays = backoff()
426+
delay = next(delays)
427+
self.logger.info(
428+
"! connect failed; reconnecting in %.1f seconds",
429+
delay,
430+
exc_info=True,
431+
)
432+
await asyncio.sleep(delay)
433+
continue
434+
435+
else:
436+
# The connection succeeded. Reset backoff.
437+
delays = None
438+
336439

337440
def unix_connect(
338441
path: str | None = None,

0 commit comments

Comments
 (0)