1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import functools
5
+ import logging
4
6
from types import TracebackType
5
- from typing import Any , Generator , Sequence
7
+ from typing import Any , AsyncIterator , Callable , Generator , Sequence
6
8
7
- from ..client import ClientProtocol
9
+ from ..client import ClientProtocol , backoff
8
10
from ..datastructures import HeadersLike
11
+ from ..exceptions import InvalidStatus
9
12
from ..extensions .base import ClientExtensionFactory
10
13
from ..extensions .permessage_deflate import enable_client_permessage_deflate
11
14
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121
124
self .response_rcvd .set_result (None )
122
125
123
126
127
+ def process_exception (exc : Exception ) -> Exception | None :
128
+ """
129
+ Determine whether an error is retryable or fatal.
130
+
131
+ When reconnecting automatically with ``async for ... in connect(...)``, if a
132
+ connection attempt fails, :func:`process_exception` is called to determine
133
+ whether to retry connecting or to raise the exception.
134
+
135
+ This function defines the default behavior, which is to retry on:
136
+
137
+ * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138
+ * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139
+ 502, 503, or 504: server or proxy errors.
140
+
141
+ All other exceptions are considered fatal.
142
+
143
+ You can change this behavior with the ``process_exception`` argument of
144
+ :func:`connect`.
145
+
146
+ Return :obj:`None` if the exception is retryable i.e. when the error could
147
+ be transient and trying to reconnect with the same parameters could succeed.
148
+ The exception will be logged at the ``INFO`` level.
149
+
150
+ Return an exception, either ``exc`` or a new exception, if the exception is
151
+ fatal i.e. when trying to reconnect will most likely produce the same error.
152
+ That exception will be raised, breaking out of the retry loop.
153
+
154
+ """
155
+ if isinstance (exc , (OSError , asyncio .TimeoutError )):
156
+ return None
157
+ if isinstance (exc , InvalidStatus ) and exc .response .status_code in [
158
+ 500 , # Internal Server Error
159
+ 502 , # Bad Gateway
160
+ 503 , # Service Unavailable
161
+ 504 , # Gateway Timeout
162
+ ]:
163
+ return None
164
+ return exc
165
+
166
+
124
167
# This is spelled in lower case because it's exposed as a callable in the API.
125
168
class connect :
126
169
"""
@@ -138,21 +181,39 @@ class connect:
138
181
139
182
The connection is closed automatically when exiting the context.
140
183
184
+ :func:`connect` can be used as an infinite asynchronous iterator to
185
+ reconnect automatically on errors::
186
+
187
+ async for websocket in connect(...):
188
+ try:
189
+ ...
190
+ except websockets.ConnectionClosed:
191
+ continue
192
+
193
+ If the connection fails with a transient error, it is retried with
194
+ exponential backoff. If it fails with a fatal error, the exception is
195
+ raised, breaking out of the loop.
196
+
197
+ The connection is closed automatically after each iteration of the loop.
198
+
141
199
Args:
142
- uri: URI of the WebSocket server.
143
- origin: Value of the ``Origin`` header, for servers that require it.
144
- extensions: List of supported extensions , in order in which they
200
+ uri: URI of the WebSocket server. origin: Value of the ``Origin``
201
+ header, for servers that require it. extensions: List of supported
202
+ extensions, in order in which they
145
203
should be negotiated and run.
146
204
subprotocols: List of supported subprotocols, in order of decreasing
147
205
preference.
148
206
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
149
207
to the handshake request.
150
208
user_agent_header: Value of the ``User-Agent`` request header.
151
- It defaults to ``"Python/x.y.z websockets/X.Y"``.
152
- Setting it to :obj:`None` removes the header.
209
+ It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
210
+ :obj:`None` removes the header.
153
211
compression: The "permessage-deflate" extension is enabled by default.
154
212
Set ``compression`` to :obj:`None` to disable it. See the
155
213
:doc:`compression guide <../../topics/compression>` for details.
214
+ process_exception: When reconnecting automatically, tell whether an
215
+ error is transient or fatal. The default behavior is defined by
216
+ :func:`process_exception`. Refer to its documentation for details.
156
217
open_timeout: Timeout for opening the connection in seconds.
157
218
:obj:`None` disables the timeout.
158
219
ping_interval: Interval between keepalive pings in seconds.
@@ -172,8 +233,8 @@ class connect:
172
233
to 32 KiB. You may pass a ``(high, low)`` tuple to set the
173
234
high-water and low-water marks.
174
235
logger: Logger for this client.
175
- It defaults to ``logging.getLogger("websockets.client")``.
176
- See the :doc:`logging guide <../../topics/logging>` for details.
236
+ It defaults to ``logging.getLogger("websockets.client")``. See the
237
+ :doc:`logging guide <../../topics/logging>` for details.
177
238
create_connection: Factory for the :class:`ClientConnection` managing
178
239
the connection. Set it to a wrapper or a subclass to customize
179
240
connection handling.
@@ -201,9 +262,8 @@ class connect:
201
262
client socket and customize it.
202
263
203
264
Raises:
204
- InvalidURI: If ``uri`` isn't a valid WebSocket URI.
205
- OSError: If the TCP connection fails.
206
- InvalidHandshake: If the opening handshake fails.
265
+ InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP
266
+ connection fails. InvalidHandshake: If the opening handshake fails.
207
267
TimeoutError: If the opening handshake times out.
208
268
209
269
"""
@@ -219,6 +279,7 @@ def __init__(
219
279
additional_headers : HeadersLike | None = None ,
220
280
user_agent_header : str | None = USER_AGENT ,
221
281
compression : str | None = "deflate" ,
282
+ process_exception : Callable [[Exception ], Exception | None ] = process_exception ,
222
283
# Timeouts
223
284
open_timeout : float | None = 10 ,
224
285
ping_interval : float | None = 20 ,
@@ -281,19 +342,28 @@ def factory() -> ClientConnection:
281
342
282
343
loop = asyncio .get_running_loop ()
283
344
if kwargs .pop ("unix" , False ):
284
- self .create_connection = loop .create_unix_connection (factory , ** kwargs )
345
+ self .create_connection = functools .partial (
346
+ loop .create_unix_connection , factory , ** kwargs
347
+ )
285
348
else :
286
349
if kwargs .get ("sock" ) is None :
287
350
kwargs .setdefault ("host" , wsuri .host )
288
351
kwargs .setdefault ("port" , wsuri .port )
289
- self .create_connection = loop .create_connection (factory , ** kwargs )
352
+ self .create_connection = functools .partial (
353
+ loop .create_connection , factory , ** kwargs
354
+ )
290
355
291
356
self .handshake_args = (
292
357
additional_headers ,
293
358
user_agent_header ,
294
359
)
295
-
360
+ self . process_exception = process_exception
296
361
self .open_timeout = open_timeout
362
+ self .logger : LoggerLike
363
+ if logger is None :
364
+ self .logger = logging .getLogger ("websockets.client" )
365
+ else :
366
+ self .logger = logger
297
367
298
368
# ... = await connect(...)
299
369
@@ -304,7 +374,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304
374
async def __await_impl__ (self ) -> ClientConnection :
305
375
try :
306
376
async with asyncio_timeout (self .open_timeout ):
307
- _transport , self .connection = await self .create_connection
377
+ _transport , self .connection = await self .create_connection ()
308
378
try :
309
379
await self .connection .handshake (* self .handshake_args )
310
380
except (Exception , asyncio .CancelledError ):
@@ -333,6 +403,39 @@ async def __aexit__(
333
403
) -> None :
334
404
await self .connection .close ()
335
405
406
+ # async for ... in connect(...):
407
+
408
+ async def __aiter__ (self ) -> AsyncIterator [ClientConnection ]:
409
+ delays : Generator [float , None , None ] | None = None
410
+ while True :
411
+ try :
412
+ async with self as protocol :
413
+ yield protocol
414
+ except Exception as exc :
415
+ # Exit the loop if the error isn't retryable.
416
+ new_exc = self .process_exception (exc )
417
+ if new_exc is exc :
418
+ raise
419
+ if new_exc is not None :
420
+ raise new_exc from exc
421
+
422
+ # The connection failed with a retryable error.
423
+ # Start or continue backoff and reconnect.
424
+ if delays is None :
425
+ delays = backoff ()
426
+ delay = next (delays )
427
+ self .logger .info (
428
+ "! connect failed; reconnecting in %.1f seconds" ,
429
+ delay ,
430
+ exc_info = True ,
431
+ )
432
+ await asyncio .sleep (delay )
433
+ continue
434
+
435
+ else :
436
+ # The connection succeeded. Reset backoff.
437
+ delays = None
438
+
336
439
337
440
def unix_connect (
338
441
path : str | None = None ,
0 commit comments