Skip to content

Commit 2bbae1b

Browse files
committed
Implement ping/pong matching (#22)
As suggested in the discussion thread, this commit mimics the ping/pong API in aaugstin's websockets. A call to ping waits for a response with the same payload. (An exception is raised if the payload matches a ping that's already in flight.) A payload is randomly generated if omitted by the caller. Add a new pong() method that can be used for sending an unsolicited pong.
1 parent e971e7b commit 2bbae1b

File tree

2 files changed

+101
-9
lines changed

2 files changed

+101
-9
lines changed

tests/test_connection.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,51 @@ async def test_client_send_and_receive(echo_conn):
207207
assert received_msg == 'This is a test message.'
208208

209209

210+
async def test_client_ping(echo_conn):
211+
async with echo_conn:
212+
await echo_conn.ping(b'A')
213+
with pytest.raises(ConnectionClosed):
214+
await echo_conn.ping(b'B')
215+
216+
217+
async def test_client_ping_two_payloads(echo_conn):
218+
pong_count = 0
219+
async def ping_and_count():
220+
nonlocal pong_count
221+
await echo_conn.ping()
222+
pong_count += 1
223+
async with echo_conn:
224+
async with trio.open_nursery() as nursery:
225+
nursery.start_soon(ping_and_count)
226+
nursery.start_soon(ping_and_count)
227+
assert pong_count == 2
228+
229+
230+
async def test_client_ping_same_payload(echo_conn):
231+
# This test verifies that two tasks can't ping with the same payload at the
232+
# same time. One of them should succeed and the other should get an
233+
# exception.
234+
exc_count = 0
235+
async def ping_and_catch():
236+
nonlocal exc_count
237+
try:
238+
await echo_conn.ping(b'A')
239+
except Exception:
240+
exc_count += 1
241+
async with echo_conn:
242+
async with trio.open_nursery() as nursery:
243+
nursery.start_soon(ping_and_catch)
244+
nursery.start_soon(ping_and_catch)
245+
assert exc_count == 1
246+
247+
248+
async def test_client_pong(echo_conn):
249+
async with echo_conn:
250+
await echo_conn.pong(b'A')
251+
with pytest.raises(ConnectionClosed):
252+
await echo_conn.pong(b'B')
253+
254+
210255
async def test_client_default_close(echo_conn):
211256
async with echo_conn:
212257
assert not echo_conn.is_closed

trio_websocket/__init__.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from collections import OrderedDict
2+
from functools import partial
13
import itertools
24
import logging
5+
import random
36
import ssl
4-
from functools import partial
7+
import struct
58

69
from async_generator import async_generator, yield_, asynccontextmanager
710
import attr
@@ -320,6 +323,7 @@ def __init__(self, stream, wsproto, path=None):
320323
self._reader_running = True
321324
self._path = path
322325
self._put_channel, self._get_channel = open_channel(0)
326+
self._pings = OrderedDict()
323327
# Set once the WebSocket open handshake takes place, i.e.
324328
# ConnectionRequested for server or ConnectedEstablished for client.
325329
self._open_handshake = trio.Event()
@@ -398,13 +402,37 @@ async def get_message(self):
398402
raise ConnectionClosed(self._close_reason) from None
399403
return message
400404

401-
async def ping(self, payload):
405+
async def ping(self, payload=None):
402406
'''
403-
Send WebSocket ping to peer.
407+
Send WebSocket ping to peer and wait for a correspoding pong.
408+
409+
Each ping is matched to its expected pong by its payload value. An
410+
exception is raised if you call ping with a ``payload`` value equal to
411+
an existing in-flight ping. If the remote endpoint recieves multiple
412+
pings, it is allowed to send a single pong. Therefore, the order of
413+
calls to ``ping()`` is tracked, and a pong will wake up its
414+
corresponding ping _as well as any earlier pings_.
415+
416+
:param payload: The payload to send. If ``None`` then a random value is
417+
created.
418+
:type payload: str, bytes, or None
419+
:raises ConnectionClosed: if connection is closed
420+
'''
421+
if self._close_reason:
422+
raise ConnectionClosed(self._close_reason)
423+
if payload in self._pings:
424+
raise Exception('Payload value {} is already in flight.'.
425+
format(payload))
426+
if payload is None:
427+
payload = struct.pack('!I', random.getrandbits(32))
428+
self._pings[payload] = trio.Event()
429+
self._wsproto.ping(payload)
430+
await self._write_pending()
431+
await self._pings[payload].wait()
404432

405-
Does not wait for pong reply. (Is this the right behavior? This may
406-
change in the future.) Raises ``ConnectionClosed`` if the connection is
407-
closed.
433+
async def pong(self, payload=None):
434+
'''
435+
Send an unsolicted pong.
408436
409437
:param payload: str or bytes payloads
410438
:raises ConnectionClosed: if connection is closed
@@ -537,18 +565,37 @@ async def _handle_ping_received_event(self, event):
537565
538566
:param event:
539567
'''
568+
logger.debug('conn#%d ping %r', self._id, event.payload)
540569
await self._write_pending()
541570

542571
async def _handle_pong_received_event(self, event):
543572
'''
544573
Handle a PongReceived event.
545574
546-
Currently we don't do anything special for a Pong frame, but this may
547-
change in the future. This handler is here as a placeholder.
575+
When a pong is received, check if we have any ping requests waiting for
576+
this pong response. If the remote endpoint skipped any earlier pings,
577+
then we wake up those skipped pings, too.
578+
579+
This function is async even though it never awaits, because the other
580+
event handlers are async, too, and event dispatch would be more
581+
complicated if some handlers were sync.
548582
549583
:param event:
550584
'''
551-
logger.debug('conn#%d pong %r', self._id, event.payload)
585+
payload = bytes(event.payload)
586+
try:
587+
event = self._pings[payload]
588+
except KeyError:
589+
# We received a pong that doesn't match any in-flight pongs. Nothing
590+
# we can do with it, so ignore it.
591+
return
592+
key, event = self._pings.popitem(0)
593+
while key != payload:
594+
logger.debug('conn#%d pong [skipped] %r', self._id, key)
595+
event.set()
596+
key, event = self._pings.popitem(0)
597+
logger.debug('conn#%d pong %r', self._id, key)
598+
event.set()
552599

553600
async def _reader_task(self):
554601
''' A background task that reads network data and generates events. '''

0 commit comments

Comments
 (0)