@@ -216,8 +216,8 @@ class connect:
216216 compression: The "permessage-deflate" extension is enabled by default.
217217 Set ``compression`` to :obj:`None` to disable it. See the
218218 :doc:`compression guide <../../topics/compression>` for details.
219- additional_headers (HeadersLike | None) : Arbitrary HTTP headers to add
220- to the handshake request.
219+ additional_headers: Arbitrary HTTP headers to add to the handshake
220+ request.
221221 user_agent_header: Value of the ``User-Agent`` request header.
222222 It defaults to ``"Python/x.y.z websockets/X.Y"``.
223223 Setting it to :obj:`None` removes the header.
@@ -328,6 +328,9 @@ def __init__(
328328 ** kwargs : Any ,
329329 ) -> None :
330330 self .uri = uri
331+ self .ws_uri = parse_uri (uri )
332+ if not self .ws_uri .secure and kwargs .get ("ssl" ) is not None :
333+ raise ValueError ("ssl argument is incompatible with a ws:// URI" )
331334
332335 if subprotocols is not None :
333336 validate_subprotocols (subprotocols )
@@ -343,7 +346,7 @@ def __init__(
343346 if create_connection is None :
344347 create_connection = ClientConnection
345348
346- def protocol_factory (uri : WebSocketURI ) -> ClientConnection :
349+ def factory (uri : WebSocketURI ) -> ClientConnection :
347350 # This is a protocol in the Sans-I/O implementation of websockets.
348351 protocol = ClientProtocol (
349352 uri ,
@@ -365,40 +368,35 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
365368 return connection
366369
367370 self .proxy = proxy
368- self .protocol_factory = protocol_factory
371+ self .factory = factory
369372 self .additional_headers = additional_headers
370373 self .user_agent_header = user_agent_header
371374 self .process_exception = process_exception
372375 self .open_timeout = open_timeout
373376 self .logger = logger
374- self .connection_kwargs = kwargs
377+ self .create_connection_kwargs = kwargs
375378
376- async def create_connection (self ) -> ClientConnection :
377- """Create TCP or Unix connection."""
379+ async def open_tcp_connection (self ) -> ClientConnection :
380+ """Create TCP or Unix connection to the server, possibly through a proxy ."""
378381 loop = asyncio .get_running_loop ()
379- kwargs = self .connection_kwargs .copy ()
380-
381- ws_uri = parse_uri (self .uri )
382+ kwargs = self .create_connection_kwargs .copy ()
382383
383384 proxy = self .proxy
384385 if kwargs .get ("unix" , False ):
385386 proxy = None
386387 if kwargs .get ("sock" ) is not None :
387388 proxy = None
388389 if proxy is True :
389- proxy = get_proxy (ws_uri )
390+ proxy = get_proxy (self . ws_uri )
390391
391392 def factory () -> ClientConnection :
392- return self .protocol_factory ( ws_uri )
393+ return self .factory ( self . ws_uri )
393394
394- if ws_uri .secure :
395+ if self . ws_uri .secure :
395396 kwargs .setdefault ("ssl" , True )
396- kwargs .setdefault ("server_hostname" , ws_uri .host )
397397 if kwargs .get ("ssl" ) is None :
398398 raise ValueError ("ssl=None is incompatible with a wss:// URI" )
399- else :
400- if kwargs .get ("ssl" ) is not None :
401- raise ValueError ("ssl argument is incompatible with a ws:// URI" )
399+ kwargs .setdefault ("server_hostname" , self .ws_uri .host )
402400
403401 if kwargs .pop ("unix" , False ):
404402 _ , connection = await loop .create_unix_connection (factory , ** kwargs )
@@ -408,7 +406,7 @@ def factory() -> ClientConnection:
408406 # Connect to the server through the proxy.
409407 sock = await connect_socks_proxy (
410408 proxy_parsed ,
411- ws_uri ,
409+ self . ws_uri ,
412410 local_addr = kwargs .pop ("local_addr" , None ),
413411 )
414412 # Initialize WebSocket connection via the proxy.
@@ -442,7 +440,7 @@ def factory() -> ClientConnection:
442440 # Connect to the server through the proxy.
443441 transport = await connect_http_proxy (
444442 proxy_parsed ,
445- ws_uri ,
443+ self . ws_uri ,
446444 user_agent_header = self .user_agent_header ,
447445 ** proxy_kwargs ,
448446 )
@@ -459,18 +457,18 @@ def factory() -> ClientConnection:
459457 assert new_transport is not None # help mypy
460458 transport = new_transport
461459 connection .connection_made (transport )
462- else :
463- raise AssertionError ( "unsupported proxy" )
460+ else : # pragma: no cover
461+ raise NotImplementedError ( f "unsupported proxy: { proxy } " )
464462 else :
465463 # Connect to the server directly.
466464 if kwargs .get ("sock" ) is None :
467- kwargs .setdefault ("host" , ws_uri .host )
468- kwargs .setdefault ("port" , ws_uri .port )
465+ kwargs .setdefault ("host" , self . ws_uri .host )
466+ kwargs .setdefault ("port" , self . ws_uri .port )
469467 # Initialize WebSocket connection.
470468 _ , connection = await loop .create_connection (factory , ** kwargs )
471469 return connection
472470
473- def process_redirect (self , exc : Exception ) -> Exception | str :
471+ def process_redirect (self , exc : Exception ) -> Exception | tuple [ str , WebSocketURI ] :
474472 """
475473 Determine whether a connection error is a redirect that can be followed.
476474
@@ -492,12 +490,12 @@ def process_redirect(self, exc: Exception) -> Exception | str:
492490 ):
493491 return exc
494492
495- old_ws_uri = parse_uri ( self .uri )
493+ old_ws_uri = self .ws_uri
496494 new_uri = urllib .parse .urljoin (self .uri , exc .response .headers ["Location" ])
497495 new_ws_uri = parse_uri (new_uri )
498496
499497 # If connect() received a socket, it is closed and cannot be reused.
500- if self .connection_kwargs .get ("sock" ) is not None :
498+ if self .create_connection_kwargs .get ("sock" ) is not None :
501499 return ValueError (
502500 f"cannot follow redirect to { new_uri } with a preexisting socket"
503501 )
@@ -513,23 +511,23 @@ def process_redirect(self, exc: Exception) -> Exception | str:
513511 or old_ws_uri .port != new_ws_uri .port
514512 ):
515513 # Cross-origin redirects on Unix sockets don't quite make sense.
516- if self .connection_kwargs .get ("unix" , False ):
514+ if self .create_connection_kwargs .get ("unix" , False ):
517515 return ValueError (
518516 f"cannot follow cross-origin redirect to { new_uri } "
519517 f"with a Unix socket"
520518 )
521519
522520 # Cross-origin redirects when host and port are overridden are ill-defined.
523521 if (
524- self .connection_kwargs .get ("host" ) is not None
525- or self .connection_kwargs .get ("port" ) is not None
522+ self .create_connection_kwargs .get ("host" ) is not None
523+ or self .create_connection_kwargs .get ("port" ) is not None
526524 ):
527525 return ValueError (
528526 f"cannot follow cross-origin redirect to { new_uri } "
529527 f"with an explicit host or port"
530528 )
531529
532- return new_uri
530+ return new_uri , new_ws_uri
533531
534532 # ... = await connect(...)
535533
@@ -541,14 +539,14 @@ async def __await_impl__(self) -> ClientConnection:
541539 try :
542540 async with asyncio_timeout (self .open_timeout ):
543541 for _ in range (MAX_REDIRECTS ):
544- self . connection = await self .create_connection ()
542+ connection = await self .open_tcp_connection ()
545543 try :
546- await self . connection .handshake (
544+ await connection .handshake (
547545 self .additional_headers ,
548546 self .user_agent_header ,
549547 )
550548 except asyncio .CancelledError :
551- self . connection .transport .abort ()
549+ connection .transport .abort ()
552550 raise
553551 except Exception as exc :
554552 # Always close the connection even though keep-alive is
@@ -557,22 +555,23 @@ async def __await_impl__(self) -> ClientConnection:
557555 # protocol. In the current design of connect(), there is
558556 # no easy way to reuse the network connection that works
559557 # in every case nor to reinitialize the protocol.
560- self . connection .transport .abort ()
558+ connection .transport .abort ()
561559
562- uri_or_exc = self .process_redirect (exc )
563- # Response is a valid redirect; follow it.
564- if isinstance (uri_or_exc , str ):
565- self .uri = uri_or_exc
566- continue
560+ exc_or_uri = self .process_redirect (exc )
567561 # Response isn't a valid redirect; raise the exception.
568- if uri_or_exc is exc :
569- raise
562+ if isinstance (exc_or_uri , Exception ):
563+ if exc_or_uri is exc :
564+ raise
565+ else :
566+ raise exc_or_uri from exc
567+ # Response is a valid redirect; follow it.
570568 else :
571- raise uri_or_exc from exc
569+ self .uri , self .ws_uri = exc_or_uri
570+ continue
572571
573572 else :
574- self . connection .start_keepalive ()
575- return self . connection
573+ connection .start_keepalive ()
574+ return connection
576575 else :
577576 raise SecurityError (f"more than { MAX_REDIRECTS } redirects" )
578577
@@ -587,24 +586,30 @@ async def __await_impl__(self) -> ClientConnection:
587586 # async with connect(...) as ...: ...
588587
589588 async def __aenter__ (self ) -> ClientConnection :
590- return await self
589+ if hasattr (self , "connection" ):
590+ raise RuntimeError ("connect() isn't reentrant" )
591+ self .connection = await self
592+ return self .connection
591593
592594 async def __aexit__ (
593595 self ,
594596 exc_type : type [BaseException ] | None ,
595597 exc_value : BaseException | None ,
596598 traceback : TracebackType | None ,
597599 ) -> None :
598- await self .connection .close ()
600+ try :
601+ await self .connection .close ()
602+ finally :
603+ del self .connection
599604
600605 # async for ... in connect(...):
601606
602607 async def __aiter__ (self ) -> AsyncIterator [ClientConnection ]:
603608 delays : Generator [float ] | None = None
604609 while True :
605610 try :
606- async with self as protocol :
607- yield protocol
611+ async with self as connection :
612+ yield connection
608613 except Exception as exc :
609614 # Determine whether the exception is retryable or fatal.
610615 # The API of process_exception is "return an exception or None";
@@ -633,7 +638,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]:
633638 traceback .format_exception_only (exc )[0 ].strip (),
634639 )
635640 await asyncio .sleep (delay )
636- continue
637641
638642 else :
639643 # The connection succeeded. Reset backoff.
@@ -777,8 +781,7 @@ def eof_received(self) -> None:
777781
778782 def connection_lost (self , exc : Exception | None ) -> None :
779783 self .reader .feed_eof ()
780- if exc is not None :
781- self .response .set_exception (exc )
784+ self .run_parser ()
782785
783786
784787async def connect_http_proxy (
@@ -797,8 +800,8 @@ async def connect_http_proxy(
797800 try :
798801 # This raises exceptions if the connection to the proxy fails.
799802 await protocol .response
800- except Exception :
801- transport .close ()
803+ except ( asyncio . CancelledError , Exception ) :
804+ transport .abort ()
802805 raise
803806
804807 return transport
0 commit comments