Skip to content

Commit b6eb11f

Browse files
committed
Merge branch 'improve_ping_v2'
2 parents 51ad398 + 8e13b26 commit b6eb11f

File tree

4 files changed

+206
-12
lines changed

4 files changed

+206
-12
lines changed

README.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,79 @@ trio.run(main)
7878
A longer example is in `examples/server.py`. **See the note above about using
7979
SSL with the example client.**
8080

81+
## Heartbeat recipe
82+
83+
If you wish to keep a connection open for long periods of time but do not need
84+
to send messages frequently, then a heartbeat holds the connection open and also
85+
detects when the connection drops unexpectedly. The following recipe
86+
demonstrates how to implement a connection heartbeat using WebSocket's ping/pong
87+
feature.
88+
89+
```python
90+
async def heartbeat(ws, timeout, interval):
91+
'''
92+
Send periodic pings on WebSocket ``ws``.
93+
94+
Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises
95+
``TooSlowError`` if the timeout is exceeded. If a pong is received, then
96+
wait ``interval`` seconds before sending the next ping.
97+
98+
This function runs until cancelled.
99+
100+
:param ws: A WebSocket to send heartbeat pings on.
101+
:param float timeout: Timeout in seconds.
102+
:param float interval: Interval between receiving pong and sending next
103+
ping, in seconds.
104+
:raises: ``ConnectionClosed`` if ``ws`` is closed.
105+
:raises: ``TooSlowError`` if the timeout expires.
106+
:returns: This function runs until cancelled.
107+
'''
108+
while True:
109+
with trio.fail_after(timeout):
110+
await ws.ping()
111+
await trio.sleep(interval)
112+
113+
async def main():
114+
async with open_websocket_url('ws://localhost/foo') as ws:
115+
async with trio.open_nursery() as nursery:
116+
nursery.start_soon(heartbeat, ws, 5, 1)
117+
# Your application code goes here:
118+
pass
119+
120+
trio.run(main)
121+
```
122+
123+
Note that the `ping()` method waits until it receives a pong frame, so it
124+
ensures that the remote endpoint is still responsive. If the connection is
125+
dropped unexpectedly or takes too long to respond, then `heartbeat()` will raise
126+
an exception that will cancel the nursery. You may wish to implement additional
127+
logic to automatically reconnect.
128+
129+
A heartbeat feature can be enabled in the example client with the
130+
``--heartbeat`` flag.
131+
132+
**Note that the WebSocket RFC does not require a WebSocket to send a pong for each
133+
ping:**
134+
135+
> If an endpoint receives a Ping frame and has not yet sent Pong frame(s) in
136+
> response to previous Ping frame(s), the endpoint MAY elect to send a Pong
137+
> frame for only the most recently processed Ping frame.
138+
139+
Therefore, if you have multiple pings in flight at the same time, you may not
140+
get an equal number of pongs in response. The simplest strategy for dealing with
141+
this is to only have one ping in flight at a time, as seen in the example above.
142+
As an alternative, you can send a `bytes` payload with each ping. The server
143+
will return the payload with the pong:
144+
145+
```python
146+
await ws.ping(b'my payload')
147+
pong == await ws.wait_pong()
148+
assert pong == b'my payload'
149+
```
150+
151+
You may want to embed a nonce or counter in the payload in order to correlate
152+
pong events to the pings you have sent.
153+
81154
## Unit Tests
82155

83156
Unit tests are written in the pytest style. You must install the development

examples/client.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def commands():
3232
def parse_args():
3333
''' Parse command line arguments. '''
3434
parser = argparse.ArgumentParser(description='Example trio-websocket client')
35+
parser.add_argument('--heartbeat', action='store_true',
36+
help='Create a heartbeat task')
3537
parser.add_argument('url', help='WebSocket URL to connect to')
3638
return parser.parse_args()
3739

@@ -53,17 +55,19 @@ async def main(args):
5355
try:
5456
logging.debug('Connecting to WebSocket…')
5557
async with open_websocket_url(args.url, ssl_context) as conn:
56-
await handle_connection(conn)
58+
await handle_connection(conn, args.heartbeat)
5759
except OSError as ose:
5860
logging.error('Connection attempt failed: %s', ose)
5961
return False
6062

6163

62-
async def handle_connection(ws):
64+
async def handle_connection(ws, use_heartbeat):
6365
''' Handle the connection. '''
6466
logging.debug('Connected!')
6567
try:
6668
async with trio.open_nursery() as nursery:
69+
if use_heartbeat:
70+
nursery.start_soon(heartbeat, ws, 1, 15)
6771
nursery.start_soon(get_commands, ws)
6872
nursery.start_soon(get_messages, ws)
6973
except ConnectionClosed as cc:
@@ -72,6 +76,30 @@ async def handle_connection(ws):
7276
print('Closed: {}/{} {}'.format(cc.reason.code, cc.reason.name, reason))
7377

7478

79+
async def heartbeat(ws, timeout, interval):
80+
'''
81+
Send periodic pings on WebSocket ``ws``.
82+
83+
Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises
84+
``TooSlowError`` if the timeout is exceeded. If a pong is received, then
85+
wait ``interval`` seconds before sending the next ping.
86+
87+
This function runs until cancelled.
88+
89+
:param ws: A WebSocket to send heartbeat pings on.
90+
:param float timeout: Timeout in seconds.
91+
:param float interval: Interval between receiving pong and sending next
92+
ping, in seconds.
93+
:raises: ``ConnectionClosed`` if ``ws`` is closed.
94+
:raises: ``TooSlowError`` if the timeout expires.
95+
:returns: This function runs until cancelled.
96+
'''
97+
while True:
98+
with trio.fail_after(timeout):
99+
await ws.ping()
100+
await trio.sleep(interval)
101+
102+
75103
async def get_commands(ws):
76104
''' In a loop: get a command from the user and execute it. '''
77105
while True:

tests/test_connection.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,51 @@ async def test_client_send_and_receive(echo_conn):
217217
assert received_msg == 'This is a test message.'
218218

219219

220+
async def test_client_ping(echo_conn):
221+
async with echo_conn:
222+
await echo_conn.ping(b'A')
223+
with pytest.raises(ConnectionClosed):
224+
await echo_conn.ping(b'B')
225+
226+
227+
async def test_client_ping_two_payloads(echo_conn):
228+
pong_count = 0
229+
async def ping_and_count():
230+
nonlocal pong_count
231+
await echo_conn.ping()
232+
pong_count += 1
233+
async with echo_conn:
234+
async with trio.open_nursery() as nursery:
235+
nursery.start_soon(ping_and_count)
236+
nursery.start_soon(ping_and_count)
237+
assert pong_count == 2
238+
239+
240+
async def test_client_ping_same_payload(echo_conn):
241+
# This test verifies that two tasks can't ping with the same payload at the
242+
# same time. One of them should succeed and the other should get an
243+
# exception.
244+
exc_count = 0
245+
async def ping_and_catch():
246+
nonlocal exc_count
247+
try:
248+
await echo_conn.ping(b'A')
249+
except Exception:
250+
exc_count += 1
251+
async with echo_conn:
252+
async with trio.open_nursery() as nursery:
253+
nursery.start_soon(ping_and_catch)
254+
nursery.start_soon(ping_and_catch)
255+
assert exc_count == 1
256+
257+
258+
async def test_client_pong(echo_conn):
259+
async with echo_conn:
260+
await echo_conn.pong(b'A')
261+
with pytest.raises(ConnectionClosed):
262+
await echo_conn.pong(b'B')
263+
264+
220265
async def test_client_default_close(echo_conn):
221266
async with echo_conn:
222267
assert not echo_conn.is_closed

trio_websocket/__init__.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from collections import OrderedDict
2+
from functools import partial
13
import itertools
24
import logging
5+
import random
36
import ssl
4-
from functools import partial
7+
import struct
58

69
from async_generator import async_generator, yield_, asynccontextmanager
710
import attr
@@ -320,6 +323,7 @@ def __init__(self, stream, wsproto, path=None):
320323
self._reader_running = True
321324
self._path = path
322325
self._put_channel, self._get_channel = open_channel(0)
326+
self._pings = OrderedDict()
323327
# Set once the WebSocket open handshake takes place, i.e.
324328
# ConnectionRequested for server or ConnectedEstablished for client.
325329
self._open_handshake = trio.Event()
@@ -398,13 +402,38 @@ async def get_message(self):
398402
raise ConnectionClosed(self._close_reason) from None
399403
return message
400404

401-
async def ping(self, payload):
405+
async def ping(self, payload=None):
402406
'''
403-
Send WebSocket ping to peer.
407+
Send WebSocket ping to peer and wait for a correspoding pong.
404408
405-
Does not wait for pong reply. (Is this the right behavior? This may
406-
change in the future.) Raises ``ConnectionClosed`` if the connection is
407-
closed.
409+
Each ping is matched to its expected pong by its payload value. An
410+
exception is raised if you call ping with a ``payload`` value equal to
411+
an existing in-flight ping. If the remote endpoint recieves multiple
412+
pings, it is allowed to send a single pong. Therefore, the order of
413+
calls to ``ping()`` is tracked, and a pong will wake up its
414+
corresponding ping _as well as any earlier pings_.
415+
416+
:param payload: The payload to send. If ``None`` then a random value is
417+
created.
418+
:type payload: str, bytes, or None
419+
:raises ConnectionClosed: if connection is closed
420+
'''
421+
if self._close_reason:
422+
raise ConnectionClosed(self._close_reason)
423+
if payload in self._pings:
424+
raise Exception('Payload value {} is already in flight.'.
425+
format(payload))
426+
if payload is None:
427+
payload = struct.pack('!I', random.getrandbits(32))
428+
event = trio.Event()
429+
self._pings[payload] = event
430+
self._wsproto.ping(payload)
431+
await self._write_pending()
432+
await event.wait()
433+
434+
async def pong(self, payload=None):
435+
'''
436+
Send an unsolicted pong.
408437
409438
:param payload: str or bytes payloads
410439
:raises ConnectionClosed: if connection is closed
@@ -537,18 +566,37 @@ async def _handle_ping_received_event(self, event):
537566
538567
:param event:
539568
'''
569+
logger.debug('conn#%d ping %r', self._id, event.payload)
540570
await self._write_pending()
541571

542572
async def _handle_pong_received_event(self, event):
543573
'''
544574
Handle a PongReceived event.
545575
546-
Currently we don't do anything special for a Pong frame, but this may
547-
change in the future. This handler is here as a placeholder.
576+
When a pong is received, check if we have any ping requests waiting for
577+
this pong response. If the remote endpoint skipped any earlier pings,
578+
then we wake up those skipped pings, too.
579+
580+
This function is async even though it never awaits, because the other
581+
event handlers are async, too, and event dispatch would be more
582+
complicated if some handlers were sync.
548583
549584
:param event:
550585
'''
551-
logger.debug('conn#%d pong %r', self._id, event.payload)
586+
payload = bytes(event.payload)
587+
try:
588+
event = self._pings[payload]
589+
except KeyError:
590+
# We received a pong that doesn't match any in-flight pongs. Nothing
591+
# we can do with it, so ignore it.
592+
return
593+
while self._pings:
594+
key, event = self._pings.popitem(0)
595+
skipped = ' [skipped] ' if payload != key else ' '
596+
logger.debug('conn#%d pong%s%r', self._id, skipped, key)
597+
event.set()
598+
if payload == key:
599+
break
552600

553601
async def _reader_task(self):
554602
''' A background task that reads network data and generates events. '''
@@ -577,7 +625,7 @@ async def _reader_task(self):
577625
event_type)
578626
await handler(event)
579627
except KeyError:
580-
logger.warning('Received unknown event type: %s',
628+
logger.warning('Received unknown event type: "%s"',
581629
event_type)
582630

583631
# Get network data.

0 commit comments

Comments
 (0)