|
| 1 | +from collections import OrderedDict |
| 2 | +from functools import partial |
1 | 3 | import itertools
|
2 | 4 | import logging
|
| 5 | +import random |
3 | 6 | import ssl
|
4 |
| -from functools import partial |
| 7 | +import struct |
5 | 8 |
|
6 | 9 | from async_generator import async_generator, yield_, asynccontextmanager
|
7 | 10 | import attr
|
@@ -320,6 +323,7 @@ def __init__(self, stream, wsproto, path=None):
|
320 | 323 | self._reader_running = True
|
321 | 324 | self._path = path
|
322 | 325 | self._put_channel, self._get_channel = open_channel(0)
|
| 326 | + self._pings = OrderedDict() |
323 | 327 | # Set once the WebSocket open handshake takes place, i.e.
|
324 | 328 | # ConnectionRequested for server or ConnectedEstablished for client.
|
325 | 329 | self._open_handshake = trio.Event()
|
@@ -398,13 +402,37 @@ async def get_message(self):
|
398 | 402 | raise ConnectionClosed(self._close_reason) from None
|
399 | 403 | return message
|
400 | 404 |
|
401 |
| - async def ping(self, payload): |
| 405 | + async def ping(self, payload=None): |
402 | 406 | '''
|
403 |
| - Send WebSocket ping to peer. |
| 407 | + Send WebSocket ping to peer and wait for a correspoding pong. |
| 408 | +
|
| 409 | + Each ping is matched to its expected pong by its payload value. An |
| 410 | + exception is raised if you call ping with a ``payload`` value equal to |
| 411 | + an existing in-flight ping. If the remote endpoint recieves multiple |
| 412 | + pings, it is allowed to send a single pong. Therefore, the order of |
| 413 | + calls to ``ping()`` is tracked, and a pong will wake up its |
| 414 | + corresponding ping _as well as any earlier pings_. |
| 415 | +
|
| 416 | + :param payload: The payload to send. If ``None`` then a random value is |
| 417 | + created. |
| 418 | + :type payload: str, bytes, or None |
| 419 | + :raises ConnectionClosed: if connection is closed |
| 420 | + ''' |
| 421 | + if self._close_reason: |
| 422 | + raise ConnectionClosed(self._close_reason) |
| 423 | + if payload in self._pings: |
| 424 | + raise Exception('Payload value {} is already in flight.'. |
| 425 | + format(payload)) |
| 426 | + if payload is None: |
| 427 | + payload = struct.pack('!I', random.getrandbits(32)) |
| 428 | + self._pings[payload] = trio.Event() |
| 429 | + self._wsproto.ping(payload) |
| 430 | + await self._write_pending() |
| 431 | + await self._pings[payload].wait() |
404 | 432 |
|
405 |
| - Does not wait for pong reply. (Is this the right behavior? This may |
406 |
| - change in the future.) Raises ``ConnectionClosed`` if the connection is |
407 |
| - closed. |
| 433 | + async def pong(self, payload=None): |
| 434 | + ''' |
| 435 | + Send an unsolicted pong. |
408 | 436 |
|
409 | 437 | :param payload: str or bytes payloads
|
410 | 438 | :raises ConnectionClosed: if connection is closed
|
@@ -537,18 +565,37 @@ async def _handle_ping_received_event(self, event):
|
537 | 565 |
|
538 | 566 | :param event:
|
539 | 567 | '''
|
| 568 | + logger.debug('conn#%d ping %r', self._id, event.payload) |
540 | 569 | await self._write_pending()
|
541 | 570 |
|
542 | 571 | async def _handle_pong_received_event(self, event):
|
543 | 572 | '''
|
544 | 573 | Handle a PongReceived event.
|
545 | 574 |
|
546 |
| - Currently we don't do anything special for a Pong frame, but this may |
547 |
| - change in the future. This handler is here as a placeholder. |
| 575 | + When a pong is received, check if we have any ping requests waiting for |
| 576 | + this pong response. If the remote endpoint skipped any earlier pings, |
| 577 | + then we wake up those skipped pings, too. |
| 578 | +
|
| 579 | + This function is async even though it never awaits, because the other |
| 580 | + event handlers are async, too, and event dispatch would be more |
| 581 | + complicated if some handlers were sync. |
548 | 582 |
|
549 | 583 | :param event:
|
550 | 584 | '''
|
551 |
| - logger.debug('conn#%d pong %r', self._id, event.payload) |
| 585 | + payload = bytes(event.payload) |
| 586 | + try: |
| 587 | + event = self._pings[payload] |
| 588 | + except KeyError: |
| 589 | + # We received a pong that doesn't match any in-flight pongs. Nothing |
| 590 | + # we can do with it, so ignore it. |
| 591 | + return |
| 592 | + key, event = self._pings.popitem(0) |
| 593 | + while key != payload: |
| 594 | + logger.debug('conn#%d pong [skipped] %r', self._id, key) |
| 595 | + event.set() |
| 596 | + key, event = self._pings.popitem(0) |
| 597 | + logger.debug('conn#%d pong %r', self._id, key) |
| 598 | + event.set() |
552 | 599 |
|
553 | 600 | async def _reader_task(self):
|
554 | 601 | ''' A background task that reads network data and generates events. '''
|
|
0 commit comments