Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sanic/server/websockets/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import secrets

from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from contextlib import suppress

from websockets.exceptions import (
ConnectionClosed,
Expand Down Expand Up @@ -552,6 +553,8 @@ async def recv(self, timeout: float | None = None) -> Data | None:
# recv was cancelled
for p in pending:
p.cancel()
with suppress(asyncio.CancelledError):
await p
raise asyncio.CancelledError()
else:
self.recv_cancel.cancel()
Expand All @@ -560,6 +563,8 @@ async def recv(self, timeout: float | None = None) -> Data | None:
# recv was cancelled
if assembler_get:
assembler_get.cancel()
with suppress(asyncio.CancelledError):
await assembler_get
raise
finally:
self.recv_cancel = None
Expand Down Expand Up @@ -612,6 +617,8 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]:
# recv_burst was cancelled
for p in pending:
p.cancel()
with suppress(asyncio.CancelledError):
await p
raise asyncio.CancelledError()
m = done_task.result()
if m is None:
Expand All @@ -629,6 +636,8 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]:
# recv_burst was cancelled
if assembler_get:
assembler_get.cancel()
with suppress(asyncio.CancelledError):
await assembler_get
raise
finally:
self.recv_cancel = None
Expand Down
89 changes: 89 additions & 0 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import re

from asyncio import Event, Queue, TimeoutError
Expand All @@ -9,8 +10,14 @@

from sanic.exceptions import ServerError
from sanic.server.websockets.frame import WebsocketFrameAssembler
from sanic.server.websockets.impl import WebsocketImplProtocol


try:
from websockets.protocol import State
except ImportError:
from websockets.connection import State

try:
from unittest.mock import AsyncMock
except ImportError:
Expand Down Expand Up @@ -238,3 +245,85 @@ async def test_ws_frame_put_skip_ctrl(opcode):
retval = await assembler.put(Frame(opcode, b""))

assert retval is None


def _make_ws_proto():
"""Create a WebsocketImplProtocol with a mock assembler that blocks.

Returns (ws, get_started, get_finished) where get_finished is an
Event that is set when the assembler.get coroutine actually exits
(either normally or via cancellation).
"""
ws_proto = Mock()
ws_proto.state = State.OPEN
ws = WebsocketImplProtocol(ws_proto, ping_interval=None, ping_timeout=None)

get_started = asyncio.Event()
get_finished = asyncio.Event()

async def slow_get(timeout=None):
get_started.set()
try:
await asyncio.sleep(10)
finally:
get_finished.set()

assembler = Mock()
assembler.get = slow_get
ws.assembler = assembler

return ws, get_started, get_finished


async def _assert_recv_awaits_assembler_on_cancel(cancel_fn):
ws, get_started, get_finished = _make_ws_proto()

recv_task = asyncio.create_task(ws.recv(timeout=5))
await get_started.wait()

# Hook into recv_lock.release (called in recv's finally block) to
# check whether the assembler task finished before recv() returned.
finished_before_return = False

original_release = ws.recv_lock.release

def check_on_release():
nonlocal finished_before_return
finished_before_return = get_finished.is_set()
original_release()

ws.recv_lock.release = check_on_release

cancel_fn(recv_task, ws)

try:
await recv_task
except asyncio.CancelledError:
pass

assert finished_before_return, (
"assembler.get() coroutine was still pending when recv() "
"returned — would cause 'Task was destroyed but it is "
"pending' on shutdown"
)


def _cancel_recv_task(recv_task, ws):
recv_task.cancel()


def _cancel_recv_waiter(recv_task, ws):
ws.recv_cancel.cancel()


RECV_CANCEL_CASES = (
pytest.param(_cancel_recv_task, id="task-cancel"),
pytest.param(_cancel_recv_waiter, id="recv-cancel"),
)


@pytest.mark.asyncio
@pytest.mark.parametrize("cancel_fn", RECV_CANCEL_CASES)
async def test_ws_recv_cancel_awaits_assembler_task(cancel_fn):
"""Cancelling recv() should clean up the assembler task."""
await _assert_recv_awaits_assembler_on_cancel(cancel_fn)
Loading