Skip to content

Commit 8d6d355

Browse files
authored
Merge pull request #67 from HyperionGray/use_trio_channels
Upgrade to trio channels (fixes #41)
2 parents facdf08 + 97410c7 commit 8d6d355

File tree

4 files changed

+20
-238
lines changed

4 files changed

+20
-238
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
'async_generator',
4040
'attrs>=18.2',
4141
'ipaddress',
42-
'trio>=0.8',
42+
'trio>=0.9,<0.10.0',
4343
'wsaccel',
4444
'wsproto>=0.12.0',
4545
'yarl'

tests/test_connection.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,23 @@ async def echo_conn_handler(conn):
6464

6565
@attr.s(hash=False, cmp=False)
6666
class MemoryListener(trio.abc.Listener):
67-
''' This class is copied from trio's own test suite. '''
6867
closed = attr.ib(default=False)
69-
accepted_streams = attr.ib(default=attr.Factory(list))
70-
queued_streams = attr.ib(default=attr.Factory(lambda: trio.Queue(1)))
68+
accepted_streams = attr.ib(factory=list)
69+
queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
7170
accept_hook = attr.ib(default=None)
7271

7372
async def connect(self):
7473
assert not self.closed
75-
client, server = trio.testing.memory_stream_pair()
76-
await self.queued_streams.put(server)
74+
client, server = memory_stream_pair()
75+
await self.queued_streams[0].send(server)
7776
return client
7877

7978
async def accept(self):
8079
await trio.hazmat.checkpoint()
8180
assert not self.closed
8281
if self.accept_hook is not None:
8382
await self.accept_hook()
84-
stream = await self.queued_streams.get()
83+
stream = await self.queued_streams[1].receive()
8584
self.accepted_streams.append(stream)
8685
return stream
8786

trio_websocket/__init__.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import wsproto.frame_protocol as wsframeproto
1717
from yarl import URL
1818

19-
from ._channel import open_channel, EndOfChannel
2019
from .version import __version__
2120

2221
RECEIVE_BYTES = 4096
@@ -452,7 +451,7 @@ def __init__(self, stream, wsproto, *, path=None):
452451
self._reader_running = True
453452
self._path = path
454453
self._subprotocol = None
455-
self._put_channel, self._get_channel = open_channel(0)
454+
self._send_channel, self._recv_channel = trio.open_memory_channel(0)
456455
self._pings = OrderedDict()
457456
# Set when the server has received a connection request event. This
458457
# future is never set on client connections.
@@ -541,8 +540,8 @@ async def get_message(self):
541540
if self._close_reason:
542541
raise ConnectionClosed(self._close_reason)
543542
try:
544-
message = await self._get_channel.get()
545-
except EndOfChannel:
543+
message = await self._recv_channel.receive()
544+
except trio.EndOfChannel:
546545
raise ConnectionClosed(self._close_reason) from None
547546
return message
548547

@@ -601,7 +600,7 @@ async def send_message(self, message):
601600
self._wsproto.send_data(message)
602601
await self._write_pending()
603602

604-
def _abort_web_socket(self):
603+
async def _abort_web_socket(self):
605604
'''
606605
If a stream is closed outside of this class, e.g. due to network
607606
conditions or because some other code closed our stream object, then we
@@ -612,7 +611,7 @@ def _abort_web_socket(self):
612611
if not self._wsproto.closed:
613612
self._wsproto.close(close_reason)
614613
if self._close_reason is None:
615-
self._close_web_socket(close_reason)
614+
await self._close_web_socket(close_reason)
616615
self._reader_running = False
617616
# We didn't really handshake, but we want any task waiting on this event
618617
# (e.g. self.aclose()) to resume.
@@ -643,7 +642,7 @@ async def _close_stream(self):
643642
# This means the TCP connection is already dead.
644643
pass
645644

646-
def _close_web_socket(self, code, reason=None):
645+
async def _close_web_socket(self, code, reason=None):
647646
'''
648647
Mark the WebSocket as closed. Close the message channel so that if any
649648
tasks are suspended in get_message(), they will wake up with a
@@ -652,7 +651,7 @@ def _close_web_socket(self, code, reason=None):
652651
self._close_reason = CloseReason(code, reason)
653652
exc = ConnectionClosed(self._close_reason)
654653
logger.debug('conn#%d websocket closed %r', self._id, exc)
655-
self._put_channel.close()
654+
await self._send_channel.aclose()
656655

657656
async def _get_request(self):
658657
'''
@@ -700,7 +699,7 @@ async def _handle_connection_closed_event(self, event):
700699
:param event:
701700
'''
702701
await self._write_pending()
703-
self._close_web_socket(event.code, event.reason or None)
702+
await self._close_web_socket(event.code, event.reason or None)
704703
self._close_handshake.set()
705704

706705
async def _handle_connection_failed_event(self, event):
@@ -710,7 +709,7 @@ async def _handle_connection_failed_event(self, event):
710709
:param event:
711710
'''
712711
await self._write_pending()
713-
self._close_web_socket(event.code, event.reason or None)
712+
await self._close_web_socket(event.code, event.reason or None)
714713
await self._close_stream()
715714
self._open_handshake.set()
716715
self._close_handshake.set()
@@ -723,7 +722,7 @@ async def _handle_bytes_received_event(self, event):
723722
'''
724723
self._bytes_message += event.data
725724
if event.message_finished:
726-
await self._put_channel.put(self._bytes_message)
725+
await self._send_channel.send(self._bytes_message)
727726
self._bytes_message = b''
728727

729728
async def _handle_text_received_event(self, event):
@@ -734,7 +733,7 @@ async def _handle_text_received_event(self, event):
734733
'''
735734
self._str_message += event.data
736735
if event.message_finished:
737-
await self._put_channel.put(self._str_message)
736+
await self._send_channel.send(self._str_message)
738737
self._str_message = ''
739738

740739
async def _handle_ping_received_event(self, event):
@@ -812,15 +811,15 @@ async def _reader_task(self):
812811
try:
813812
data = await self._stream.receive_some(RECEIVE_BYTES)
814813
except (trio.BrokenResourceError, trio.ClosedResourceError):
815-
self._abort_web_socket()
814+
await self._abort_web_socket()
816815
break
817816
if len(data) == 0:
818817
logger.debug('conn#%d received zero bytes (connection closed)',
819818
self._id)
820819
# If TCP closed before WebSocket, then record it as an abnormal
821820
# closure.
822821
if not self._wsproto.closed:
823-
self._abort_web_socket()
822+
await self._abort_web_socket()
824823
break
825824
else:
826825
logger.debug('conn#%d received %d bytes', self._id, len(data))
@@ -839,7 +838,7 @@ async def _write_pending(self):
839838
try:
840839
await self._stream.send_all(data)
841840
except (trio.BrokenResourceError, trio.ClosedResourceError):
842-
self._abort_web_socket()
841+
await self._abort_web_socket()
843842
raise ConnectionClosed(self._close_reason) from None
844843
else:
845844
logger.debug('conn#%d no pending data to send', self._id)

trio_websocket/_channel.py

Lines changed: 0 additions & 216 deletions
This file was deleted.

0 commit comments

Comments
 (0)