Skip to content

Commit 1d12c64

Browse files
committed
Align the asyncio and sync client and server modules.
1 parent eba390f commit 1d12c64

File tree

5 files changed

+128
-99
lines changed

5 files changed

+128
-99
lines changed

src/websockets/asyncio/client.py

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ class connect:
216216
compression: The "permessage-deflate" extension is enabled by default.
217217
Set ``compression`` to :obj:`None` to disable it. See the
218218
:doc:`compression guide <../../topics/compression>` for details.
219-
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
220-
to the handshake request.
219+
additional_headers: Arbitrary HTTP headers to add to the handshake
220+
request.
221221
user_agent_header: Value of the ``User-Agent`` request header.
222222
It defaults to ``"Python/x.y.z websockets/X.Y"``.
223223
Setting it to :obj:`None` removes the header.
@@ -328,6 +328,9 @@ def __init__(
328328
**kwargs: Any,
329329
) -> None:
330330
self.uri = uri
331+
self.ws_uri = parse_uri(uri)
332+
if not self.ws_uri.secure and kwargs.get("ssl") is not None:
333+
raise ValueError("ssl argument is incompatible with a ws:// URI")
331334

332335
if subprotocols is not None:
333336
validate_subprotocols(subprotocols)
@@ -343,7 +346,7 @@ def __init__(
343346
if create_connection is None:
344347
create_connection = ClientConnection
345348

346-
def protocol_factory(uri: WebSocketURI) -> ClientConnection:
349+
def factory(uri: WebSocketURI) -> ClientConnection:
347350
# This is a protocol in the Sans-I/O implementation of websockets.
348351
protocol = ClientProtocol(
349352
uri,
@@ -365,40 +368,35 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
365368
return connection
366369

367370
self.proxy = proxy
368-
self.protocol_factory = protocol_factory
371+
self.factory = factory
369372
self.additional_headers = additional_headers
370373
self.user_agent_header = user_agent_header
371374
self.process_exception = process_exception
372375
self.open_timeout = open_timeout
373376
self.logger = logger
374-
self.connection_kwargs = kwargs
377+
self.create_connection_kwargs = kwargs
375378

376-
async def create_connection(self) -> ClientConnection:
377-
"""Create TCP or Unix connection."""
379+
async def open_tcp_connection(self) -> ClientConnection:
380+
"""Create TCP or Unix connection to the server, possibly through a proxy."""
378381
loop = asyncio.get_running_loop()
379-
kwargs = self.connection_kwargs.copy()
380-
381-
ws_uri = parse_uri(self.uri)
382+
kwargs = self.create_connection_kwargs.copy()
382383

383384
proxy = self.proxy
384385
if kwargs.get("unix", False):
385386
proxy = None
386387
if kwargs.get("sock") is not None:
387388
proxy = None
388389
if proxy is True:
389-
proxy = get_proxy(ws_uri)
390+
proxy = get_proxy(self.ws_uri)
390391

391392
def factory() -> ClientConnection:
392-
return self.protocol_factory(ws_uri)
393+
return self.factory(self.ws_uri)
393394

394-
if ws_uri.secure:
395+
if self.ws_uri.secure:
395396
kwargs.setdefault("ssl", True)
396-
kwargs.setdefault("server_hostname", ws_uri.host)
397397
if kwargs.get("ssl") is None:
398398
raise ValueError("ssl=None is incompatible with a wss:// URI")
399-
else:
400-
if kwargs.get("ssl") is not None:
401-
raise ValueError("ssl argument is incompatible with a ws:// URI")
399+
kwargs.setdefault("server_hostname", self.ws_uri.host)
402400

403401
if kwargs.pop("unix", False):
404402
_, connection = await loop.create_unix_connection(factory, **kwargs)
@@ -408,7 +406,7 @@ def factory() -> ClientConnection:
408406
# Connect to the server through the proxy.
409407
sock = await connect_socks_proxy(
410408
proxy_parsed,
411-
ws_uri,
409+
self.ws_uri,
412410
local_addr=kwargs.pop("local_addr", None),
413411
)
414412
# Initialize WebSocket connection via the proxy.
@@ -442,7 +440,7 @@ def factory() -> ClientConnection:
442440
# Connect to the server through the proxy.
443441
transport = await connect_http_proxy(
444442
proxy_parsed,
445-
ws_uri,
443+
self.ws_uri,
446444
user_agent_header=self.user_agent_header,
447445
**proxy_kwargs,
448446
)
@@ -459,18 +457,18 @@ def factory() -> ClientConnection:
459457
assert new_transport is not None # help mypy
460458
transport = new_transport
461459
connection.connection_made(transport)
462-
else:
463-
raise AssertionError("unsupported proxy")
460+
else: # pragma: no cover
461+
raise NotImplementedError(f"unsupported proxy: {proxy}")
464462
else:
465463
# Connect to the server directly.
466464
if kwargs.get("sock") is None:
467-
kwargs.setdefault("host", ws_uri.host)
468-
kwargs.setdefault("port", ws_uri.port)
465+
kwargs.setdefault("host", self.ws_uri.host)
466+
kwargs.setdefault("port", self.ws_uri.port)
469467
# Initialize WebSocket connection.
470468
_, connection = await loop.create_connection(factory, **kwargs)
471469
return connection
472470

473-
def process_redirect(self, exc: Exception) -> Exception | str:
471+
def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]:
474472
"""
475473
Determine whether a connection error is a redirect that can be followed.
476474
@@ -492,12 +490,12 @@ def process_redirect(self, exc: Exception) -> Exception | str:
492490
):
493491
return exc
494492

495-
old_ws_uri = parse_uri(self.uri)
493+
old_ws_uri = self.ws_uri
496494
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
497495
new_ws_uri = parse_uri(new_uri)
498496

499497
# If connect() received a socket, it is closed and cannot be reused.
500-
if self.connection_kwargs.get("sock") is not None:
498+
if self.create_connection_kwargs.get("sock") is not None:
501499
return ValueError(
502500
f"cannot follow redirect to {new_uri} with a preexisting socket"
503501
)
@@ -513,23 +511,23 @@ def process_redirect(self, exc: Exception) -> Exception | str:
513511
or old_ws_uri.port != new_ws_uri.port
514512
):
515513
# Cross-origin redirects on Unix sockets don't quite make sense.
516-
if self.connection_kwargs.get("unix", False):
514+
if self.create_connection_kwargs.get("unix", False):
517515
return ValueError(
518516
f"cannot follow cross-origin redirect to {new_uri} "
519517
f"with a Unix socket"
520518
)
521519

522520
# Cross-origin redirects when host and port are overridden are ill-defined.
523521
if (
524-
self.connection_kwargs.get("host") is not None
525-
or self.connection_kwargs.get("port") is not None
522+
self.create_connection_kwargs.get("host") is not None
523+
or self.create_connection_kwargs.get("port") is not None
526524
):
527525
return ValueError(
528526
f"cannot follow cross-origin redirect to {new_uri} "
529527
f"with an explicit host or port"
530528
)
531529

532-
return new_uri
530+
return new_uri, new_ws_uri
533531

534532
# ... = await connect(...)
535533

@@ -541,14 +539,14 @@ async def __await_impl__(self) -> ClientConnection:
541539
try:
542540
async with asyncio_timeout(self.open_timeout):
543541
for _ in range(MAX_REDIRECTS):
544-
self.connection = await self.create_connection()
542+
connection = await self.open_tcp_connection()
545543
try:
546-
await self.connection.handshake(
544+
await connection.handshake(
547545
self.additional_headers,
548546
self.user_agent_header,
549547
)
550548
except asyncio.CancelledError:
551-
self.connection.transport.abort()
549+
connection.transport.abort()
552550
raise
553551
except Exception as exc:
554552
# Always close the connection even though keep-alive is
@@ -557,22 +555,23 @@ async def __await_impl__(self) -> ClientConnection:
557555
# protocol. In the current design of connect(), there is
558556
# no easy way to reuse the network connection that works
559557
# in every case nor to reinitialize the protocol.
560-
self.connection.transport.abort()
558+
connection.transport.abort()
561559

562-
uri_or_exc = self.process_redirect(exc)
563-
# Response is a valid redirect; follow it.
564-
if isinstance(uri_or_exc, str):
565-
self.uri = uri_or_exc
566-
continue
560+
exc_or_uri = self.process_redirect(exc)
567561
# Response isn't a valid redirect; raise the exception.
568-
if uri_or_exc is exc:
569-
raise
562+
if isinstance(exc_or_uri, Exception):
563+
if exc_or_uri is exc:
564+
raise
565+
else:
566+
raise exc_or_uri from exc
567+
# Response is a valid redirect; follow it.
570568
else:
571-
raise uri_or_exc from exc
569+
self.uri, self.ws_uri = exc_or_uri
570+
continue
572571

573572
else:
574-
self.connection.start_keepalive()
575-
return self.connection
573+
connection.start_keepalive()
574+
return connection
576575
else:
577576
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
578577

@@ -587,24 +586,30 @@ async def __await_impl__(self) -> ClientConnection:
587586
# async with connect(...) as ...: ...
588587

589588
async def __aenter__(self) -> ClientConnection:
590-
return await self
589+
if hasattr(self, "connection"):
590+
raise RuntimeError("connect() isn't reentrant")
591+
self.connection = await self
592+
return self.connection
591593

592594
async def __aexit__(
593595
self,
594596
exc_type: type[BaseException] | None,
595597
exc_value: BaseException | None,
596598
traceback: TracebackType | None,
597599
) -> None:
598-
await self.connection.close()
600+
try:
601+
await self.connection.close()
602+
finally:
603+
del self.connection
599604

600605
# async for ... in connect(...):
601606

602607
async def __aiter__(self) -> AsyncIterator[ClientConnection]:
603608
delays: Generator[float] | None = None
604609
while True:
605610
try:
606-
async with self as protocol:
607-
yield protocol
611+
async with self as connection:
612+
yield connection
608613
except Exception as exc:
609614
# Determine whether the exception is retryable or fatal.
610615
# The API of process_exception is "return an exception or None";
@@ -633,7 +638,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]:
633638
traceback.format_exception_only(exc)[0].strip(),
634639
)
635640
await asyncio.sleep(delay)
636-
continue
637641

638642
else:
639643
# The connection succeeded. Reset backoff.
@@ -777,8 +781,7 @@ def eof_received(self) -> None:
777781

778782
def connection_lost(self, exc: Exception | None) -> None:
779783
self.reader.feed_eof()
780-
if exc is not None:
781-
self.response.set_exception(exc)
784+
self.run_parser()
782785

783786

784787
async def connect_http_proxy(
@@ -797,8 +800,8 @@ async def connect_http_proxy(
797800
try:
798801
# This raises exceptions if the connection to the proxy fails.
799802
await protocol.response
800-
except Exception:
801-
transport.close()
803+
except (asyncio.CancelledError, Exception):
804+
transport.abort()
802805
raise
803806

804807
return transport

src/websockets/asyncio/server.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async def handshake(
169169
assert isinstance(response, Response) # help mypy
170170
self.response = response
171171

172-
if server_header:
172+
if server_header is not None:
173173
self.response.headers["Server"] = server_header
174174

175175
response = None
@@ -231,12 +231,9 @@ class Server:
231231
232232
This class mirrors the API of :class:`asyncio.Server`.
233233
234-
It keeps track of WebSocket connections in order to close them properly
235-
when shutting down.
236-
237234
Args:
238235
handler: Connection handler. It receives the WebSocket connection,
239-
which is a :class:`ServerConnection`, in argument.
236+
which is a :class:`ServerConnection`.
240237
process_request: Intercept the request during the opening handshake.
241238
Return an HTTP response to force the response. Return :obj:`None` to
242239
continue normally. When you force an HTTP 101 Continue response, the
@@ -310,7 +307,11 @@ def connections(self) -> set[ServerConnection]:
310307
It can be useful in combination with :func:`~broadcast`.
311308
312309
"""
313-
return {connection for connection in self.handlers if connection.state is OPEN}
310+
return {
311+
connection
312+
for connection in self.handlers
313+
if connection.protocol.state is OPEN
314+
}
314315

315316
def wrap(self, server: asyncio.Server) -> None:
316317
"""
@@ -351,6 +352,8 @@ async def conn_handler(self, connection: ServerConnection) -> None:
351352
352353
"""
353354
try:
355+
# Apply open_timeout to the WebSocket handshake.
356+
# Use ssl_handshake_timeout for the TLS handshake.
354357
async with asyncio_timeout(self.open_timeout):
355358
try:
356359
await connection.handshake(
@@ -425,7 +428,7 @@ def close(
425428
``code`` and ``reason`` can be customized, for example to use code
426429
1012 (service restart).
427430
428-
* Wait until all connection handlers terminate.
431+
* Wait until all connection handlers have returned.
429432
430433
:meth:`close` is idempotent.
431434
@@ -452,22 +455,20 @@ async def _close(
452455
self.logger.info("server closing")
453456

454457
# Stop accepting new connections.
458+
# Reject OPENING connections with HTTP 503 -- see handshake().
455459
self.server.close()
456460

457461
# Wait until all accepted connections reach connection_made() and call
458462
# register(). See https://github.com/python/cpython/issues/79033 for
459463
# details. This workaround can be removed when dropping Python < 3.11.
460464
await asyncio.sleep(0)
461465

462-
# After server.close(), handshake() closes OPENING connections with an
463-
# HTTP 503 error.
464-
466+
# Close OPEN connections.
465467
if close_connections:
466-
# Close OPEN connections with code 1001 by default.
467468
close_tasks = [
468469
asyncio.create_task(connection.close(code, reason))
469470
for connection in self.handlers
470-
if connection.protocol.state is not CONNECTING
471+
if connection.protocol.state is OPEN
471472
]
472473
# asyncio.wait doesn't accept an empty first argument.
473474
if close_tasks:
@@ -476,7 +477,7 @@ async def _close(
476477
# Wait until all TCP connections are closed.
477478
await self.server.wait_closed()
478479

479-
# Wait until all connection handlers terminate.
480+
# Wait until all connection handlers have returned.
480481
# asyncio.wait doesn't accept an empty first argument.
481482
if self.handlers:
482483
await asyncio.wait(self.handlers.values())
@@ -590,18 +591,18 @@ class serve:
590591
591592
This coroutine returns a :class:`Server` whose API mirrors
592593
:class:`asyncio.Server`. Treat it as an asynchronous context manager to
593-
ensure that the server will be closed::
594+
ensure that the server will be closed gracefully::
594595
595596
from websockets.asyncio.server import serve
596597
597-
def handler(websocket):
598+
async def handler(websocket):
598599
...
599600
600-
# set this future to exit the server
601-
stop = asyncio.get_running_loop().create_future()
601+
# set this event to exit the server
602+
stop = asyncio.Event()
602603
603604
async with serve(handler, host, port):
604-
await stop
605+
await stop.wait()
605606
606607
Alternatively, call :meth:`~Server.serve_forever` to serve requests and
607608
cancel it to stop the server::

0 commit comments

Comments
 (0)