1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import functools
5
+ import logging
4
6
from types import TracebackType
5
- from typing import Any , Generator , Sequence
7
+ from typing import Any , AsyncIterator , Callable , Generator , Sequence
6
8
7
- from ..client import ClientProtocol
9
+ from ..client import ClientProtocol , backoff
8
10
from ..datastructures import HeadersLike
11
+ from ..exceptions import InvalidStatus
9
12
from ..extensions .base import ClientExtensionFactory
10
13
from ..extensions .permessage_deflate import enable_client_permessage_deflate
11
14
from ..headers import validate_subprotocols
@@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
121
124
self .response_rcvd .set_result (None )
122
125
123
126
127
+ def process_exception (exc : Exception ) -> Exception | None :
128
+ """
129
+ Determine whether an error is retryable or fatal.
130
+
131
+ When reconnecting automatically with ``async for ... in connect(...)``, if a
132
+ connection attempt fails, :func:`process_exception` is called to determine
133
+ whether to retry connecting or to raise the exception.
134
+
135
+ This function defines the default behavior, which is to retry on:
136
+
137
+ * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
138
+ * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
139
+ 502, 503, or 504: server or proxy errors.
140
+
141
+ All other exceptions are considered fatal.
142
+
143
+ You can change this behavior with the ``process_exception`` argument of
144
+ :func:`connect`.
145
+
146
+ Return :obj:`None` if the exception is retryable i.e. when the error could
147
+ be transient and trying to reconnect with the same parameters could succeed.
148
+ The exception will be logged at the ``INFO`` level.
149
+
150
+ Return an exception, either ``exc`` or a new exception, if the exception is
151
+ fatal i.e. when trying to reconnect will most likely produce the same error.
152
+ That exception will be raised, breaking out of the retry loop.
153
+
154
+ """
155
+ if isinstance (exc , (OSError , asyncio .TimeoutError )):
156
+ return None
157
+ if isinstance (exc , InvalidStatus ) and exc .response .status_code in [
158
+ 500 , # Internal Server Error
159
+ 502 , # Bad Gateway
160
+ 503 , # Service Unavailable
161
+ 504 , # Gateway Timeout
162
+ ]:
163
+ return None
164
+ return exc
165
+
166
+
124
167
# This is spelled in lower case because it's exposed as a callable in the API.
125
168
class connect :
126
169
"""
@@ -138,6 +181,21 @@ class connect:
138
181
139
182
The connection is closed automatically when exiting the context.
140
183
184
+ :func:`connect` can be used as an infinite asynchronous iterator to
185
+ reconnect automatically on errors::
186
+
187
+ async for websocket in connect(...):
188
+ try:
189
+ ...
190
+ except websockets.ConnectionClosed:
191
+ continue
192
+
193
+ If the connection fails with a transient error, it is retried with
194
+ exponential backoff. If it fails with a fatal error, the exception is
195
+ raised, breaking out of the loop.
196
+
197
+ The connection is closed automatically after each iteration of the loop.
198
+
141
199
Args:
142
200
uri: URI of the WebSocket server.
143
201
origin: Value of the ``Origin`` header, for servers that require it.
@@ -153,6 +211,9 @@ class connect:
153
211
compression: The "permessage-deflate" extension is enabled by default.
154
212
Set ``compression`` to :obj:`None` to disable it. See the
155
213
:doc:`compression guide <../../topics/compression>` for details.
214
+ process_exception: When reconnecting automatically, tell whether an
215
+ error is transient or fatal. The default behavior is defined by
216
+ :func:`process_exception`. Refer to its documentation for details.
156
217
open_timeout: Timeout for opening the connection in seconds.
157
218
:obj:`None` disables the timeout.
158
219
ping_interval: Interval between keepalive pings in seconds.
@@ -219,6 +280,7 @@ def __init__(
219
280
additional_headers : HeadersLike | None = None ,
220
281
user_agent_header : str | None = USER_AGENT ,
221
282
compression : str | None = "deflate" ,
283
+ process_exception : Callable [[Exception ], Exception | None ] = process_exception ,
222
284
# Timeouts
223
285
open_timeout : float | None = 10 ,
224
286
ping_interval : float | None = 20 ,
@@ -281,19 +343,26 @@ def factory() -> ClientConnection:
281
343
282
344
loop = asyncio .get_running_loop ()
283
345
if kwargs .pop ("unix" , False ):
284
- self .create_connection = loop .create_unix_connection (factory , ** kwargs )
346
+ self .create_connection = functools .partial (
347
+ loop .create_unix_connection , factory , ** kwargs
348
+ )
285
349
else :
286
350
if kwargs .get ("sock" ) is None :
287
351
kwargs .setdefault ("host" , wsuri .host )
288
352
kwargs .setdefault ("port" , wsuri .port )
289
- self .create_connection = loop .create_connection (factory , ** kwargs )
353
+ self .create_connection = functools .partial (
354
+ loop .create_connection , factory , ** kwargs
355
+ )
290
356
291
357
self .handshake_args = (
292
358
additional_headers ,
293
359
user_agent_header ,
294
360
)
295
-
361
+ self . process_exception = process_exception
296
362
self .open_timeout = open_timeout
363
+ if logger is None :
364
+ logger = logging .getLogger ("websockets.client" )
365
+ self .logger = logger
297
366
298
367
# ... = await connect(...)
299
368
@@ -304,7 +373,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
304
373
async def __await_impl__ (self ) -> ClientConnection :
305
374
try :
306
375
async with asyncio_timeout (self .open_timeout ):
307
- _transport , self .connection = await self .create_connection
376
+ _transport , self .connection = await self .create_connection ()
308
377
try :
309
378
await self .connection .handshake (* self .handshake_args )
310
379
except (Exception , asyncio .CancelledError ):
@@ -333,6 +402,48 @@ async def __aexit__(
333
402
) -> None :
334
403
await self .connection .close ()
335
404
405
+ # async for ... in connect(...):
406
+
407
+ async def __aiter__ (self ) -> AsyncIterator [ClientConnection ]:
408
+ delays : Generator [float , None , None ] | None = None
409
+ while True :
410
+ try :
411
+ async with self as protocol :
412
+ yield protocol
413
+ except Exception as exc :
414
+ # Determine whether the exception is retryable or fatal.
415
+ # The API of process_exception is "return an exception or None";
416
+ # "raise an exception" is also supported because it's a frequent
417
+ # mistake. It isn't documented in order to keep the API simple.
418
+ try :
419
+ new_exc = self .process_exception (exc )
420
+ except Exception as raised_exc :
421
+ new_exc = raised_exc
422
+
423
+ # The connection failed with a fatal error.
424
+ # Raise the exception and exit the loop.
425
+ if new_exc is exc :
426
+ raise
427
+ if new_exc is not None :
428
+ raise new_exc from exc
429
+
430
+ # The connection failed with a retryable error.
431
+ # Start or continue backoff and reconnect.
432
+ if delays is None :
433
+ delays = backoff ()
434
+ delay = next (delays )
435
+ self .logger .info (
436
+ "! connect failed; reconnecting in %.1f seconds" ,
437
+ delay ,
438
+ exc_info = True ,
439
+ )
440
+ await asyncio .sleep (delay )
441
+ continue
442
+
443
+ else :
444
+ # The connection succeeded. Reset backoff.
445
+ delays = None
446
+
336
447
337
448
def unix_connect (
338
449
path : str | None = None ,
0 commit comments