1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import functools
5
+ import logging
4
6
from types import TracebackType
5
- from typing import Any , Generator , Sequence
7
+ from typing import Any , AsyncIterator , Callable , Generator , Sequence
6
8
7
- from ..client import ClientProtocol
9
+ from ..client import ClientProtocol , backoff
8
10
from ..datastructures import HeadersLike
11
+ from ..exceptions import InvalidStatus
9
12
from ..extensions .base import ClientExtensionFactory
10
13
from ..extensions .permessage_deflate import enable_client_permessage_deflate
11
14
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121
124
self .response_rcvd .set_result (None )
122
125
123
126
127
+ def process_exception (exc : Exception ) -> Exception | None :
128
+ """
129
+ Determine whether an error is retriable or fatal.
130
+
131
+ When reconnecting automatically with ``async for ... in connect(...)``, if a
132
+ connection attempt fails, :func:`process_exception` is called to determine
133
+ whether to retry connecting or to raise the exception.
134
+
135
+ This function defines the default behavior, which is to retry on:
136
+
137
+ * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138
+ * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139
+ 502, 503, or 504: server or proxy errors.
140
+
141
+ All other exceptions are considered fatal.
142
+
143
+ You can change this behavior with the ``process_exception`` argument of
144
+ :func:`connect`.
145
+
146
+ Return :obj:`None` if the exception is retriable i.e. when the error could
147
+ be transient and trying to reconnect with the same parameters could succeed.
148
+ The exception will be logged at the ``INFO`` level.
149
+
150
+ Return an exception, either ``exc`` or a new exception, if the exception is
151
+ fatal i.e. when trying to reconnect will most likely produce the same error.
152
+ That exception will be raised, breaking out of the retry loop.
153
+
154
+ """
155
+ if isinstance (exc , (OSError , asyncio .TimeoutError )):
156
+ return None
157
+ if isinstance (exc , InvalidStatus ) and exc .response .status_code in [
158
+ 500 , # Internal Server Error
159
+ 502 , # Bad Gateway
160
+ 503 , # Service Unavailable
161
+ 504 , # Gateway Timeout
162
+ ]:
163
+ return None
164
+ return exc
165
+
166
+
124
167
# This is spelled in lower case because it's exposed as a callable in the API.
125
168
class connect :
126
169
"""
@@ -138,6 +181,21 @@ class connect:
138
181
139
182
The connection is closed automatically when exiting the context.
140
183
184
+ :func:`connect` can be used as an infinite asynchronous iterator to
185
+ reconnect automatically on errors::
186
+
187
+ async for websocket in connect(...):
188
+ try:
189
+ ...
190
+ except websockets.ConnectionClosed:
191
+ continue
192
+
193
+ If the connection fails with a transient error, it is retried with
194
+ exponential backoff. If it fails with a fatal error, the exception is
195
+ raised, breaking out of the loop.
196
+
197
+ The connection is closed automatically after each iteration of the loop.
198
+
141
199
Args:
142
200
uri: URI of the WebSocket server.
143
201
origin: Value of the ``Origin`` header, for servers that require it.
@@ -153,6 +211,9 @@ class connect:
153
211
compression: The "permessage-deflate" extension is enabled by default.
154
212
Set ``compression`` to :obj:`None` to disable it. See the
155
213
:doc:`compression guide <../../topics/compression>` for details.
214
+ process_exception: When reconnecting automatically, tell whether an
215
+ error is transient or fatal. The default behavior is defined by
216
+ :func:`process_exception`. Refer to its documentation for details.
156
217
open_timeout: Timeout for opening the connection in seconds.
157
218
:obj:`None` disables the timeout.
158
219
ping_interval: Interval between keepalive pings in seconds.
@@ -219,6 +280,7 @@ def __init__(
219
280
additional_headers : HeadersLike | None = None ,
220
281
user_agent_header : str | None = USER_AGENT ,
221
282
compression : str | None = "deflate" ,
283
+ process_exception : Callable [[Exception ], Exception | None ] = process_exception ,
222
284
# Timeouts
223
285
open_timeout : float | None = 10 ,
224
286
ping_interval : float | None = 20 ,
@@ -281,19 +343,29 @@ def factory() -> ClientConnection:
281
343
282
344
loop = asyncio .get_running_loop ()
283
345
if kwargs .pop ("unix" , False ):
284
- self .create_connection = loop .create_unix_connection (factory , ** kwargs )
346
+ self .create_connection = functools .partial (
347
+ loop .create_unix_connection , factory , ** kwargs
348
+ )
285
349
else :
286
350
if kwargs .get ("sock" ) is None :
287
351
kwargs .setdefault ("host" , wsuri .host )
288
352
kwargs .setdefault ("port" , wsuri .port )
289
- self .create_connection = loop .create_connection (factory , ** kwargs )
353
+ self .create_connection = functools .partial (
354
+ loop .create_connection , factory , ** kwargs
355
+ )
290
356
291
357
self .handshake_args = (
292
358
additional_headers ,
293
359
user_agent_header ,
294
360
)
295
-
361
+ self . process_exception = process_exception
296
362
self .open_timeout = open_timeout
363
+ self .logger : LoggerLike
364
+ # TODO - I'm not happy with this
365
+ if logger is None :
366
+ self .logger = logging .getLogger ("websockets.client" )
367
+ else : # pragma: no cover
368
+ self .logger = logger
297
369
298
370
# ... = await connect(...)
299
371
@@ -304,7 +376,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304
376
async def __await_impl__ (self ) -> ClientConnection :
305
377
try :
306
378
async with asyncio_timeout (self .open_timeout ):
307
- _transport , self .connection = await self .create_connection
379
+ _transport , self .connection = await self .create_connection ()
308
380
try :
309
381
await self .connection .handshake (* self .handshake_args )
310
382
except (Exception , asyncio .CancelledError ):
@@ -333,6 +405,48 @@ async def __aexit__(
333
405
) -> None :
334
406
await self .connection .close ()
335
407
408
+ # async for ... in connect(...):
409
+
410
+ async def __aiter__ (self ) -> AsyncIterator [ClientConnection ]:
411
+ delays : Generator [float , None , None ] | None = None
412
+ while True :
413
+ try :
414
+ async with self as protocol :
415
+ yield protocol
416
+ except Exception as exc :
417
+ # Determine whether the exception is retriable or fatal.
418
+ # The API of process_exception is "return an exception or None";
419
+ # "raise an exception" is also supported because it's a frequent
420
+ # mistake. It isn't documented in order to keep the API simple.
421
+ try :
422
+ new_exc = self .process_exception (exc )
423
+ except Exception as raised_exc :
424
+ new_exc = raised_exc
425
+
426
+ # The connection failed with a fatal error.
427
+ # Raise the exception and exit the loop.
428
+ if new_exc is exc :
429
+ raise
430
+ if new_exc is not None :
431
+ raise new_exc from exc
432
+
433
+ # The connection failed with a retriable error.
434
+ # Start or continue backoff and reconnect.
435
+ if delays is None :
436
+ delays = backoff ()
437
+ delay = next (delays )
438
+ self .logger .info (
439
+ "! connect failed; reconnecting in %.1f seconds" ,
440
+ delay ,
441
+ exc_info = True ,
442
+ )
443
+ await asyncio .sleep (delay )
444
+ continue
445
+
446
+ else :
447
+ # The connection succeeded. Reset backoff.
448
+ delays = None
449
+
336
450
337
451
def unix_connect (
338
452
path : str | None = None ,
0 commit comments