Skip to content

Commit e7706f4

Browse files
authored
Merge pull request #192 from jakkdl/exc_group_cause_context
fix loss of context/cause on exceptions raised inside open_websocket
2 parents f5fd6d7 + b8d1fc7 commit e7706f4

File tree

2 files changed

+104
-29
lines changed

2 files changed

+104
-29
lines changed

tests/test_connection.py

+54-11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
'''
3232
from __future__ import annotations
3333

34+
import copy
3435
from functools import partial, wraps
3536
import re
3637
import ssl
@@ -452,7 +453,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
452453
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
453454
"""
454455
async def ki_raising_ping_handler(*args, **kwargs) -> None:
455-
print("raising ki")
456456
raise KeyboardInterrupt
457457
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
458458
async def handler(request):
@@ -474,27 +474,32 @@ async def handler(request):
474474
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
475475
"""_reader_task._handle_ping_event triggers ValueError.
476476
user code also raises exception.
477-
internal exception is in __cause__ exceptiongroup and user exc is delivered
477+
internal exception is in __context__ exceptiongroup and user exc is delivered
478478
"""
479-
my_value_error = ValueError()
479+
internal_error = ValueError()
480+
internal_error.__context__ = TypeError()
481+
user_error = NameError()
482+
user_error_context = KeyError()
480483
async def raising_ping_event(*args, **kwargs) -> None:
481-
raise my_value_error
484+
raise internal_error
482485

483486
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
484487
async def handler(request):
485488
server_ws = await request.accept()
486489
await server_ws.ping(b"a")
487490

488491
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
489-
with pytest.raises(trio.TooSlowError) as exc_info:
492+
with pytest.raises(type(user_error)) as exc_info:
490493
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
491-
with trio.fail_after(1) as cs:
492-
cs.shield = True
493-
await trio.sleep(2)
494+
await trio.lowlevel.checkpoint()
495+
user_error.__context__ = user_error_context
496+
raise user_error
494497

495-
e_cause = exc_info.value.__cause__
496-
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
497-
assert my_value_error in e_cause.exceptions
498+
assert exc_info.value is user_error
499+
e_context = exc_info.value.__context__
500+
assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment
501+
assert internal_error in e_context.exceptions
502+
assert user_error_context in e_context.exceptions
498503

499504
@fail_after(5)
500505
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
@@ -513,6 +518,8 @@ async def handler(request):
513518
server_ws = await request.accept()
514519
await server_ws.ping(b"a")
515520
user_cancelled = None
521+
user_cancelled_cause = None
522+
user_cancelled_context = None
516523

517524
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
518525
with trio.move_on_after(2):
@@ -522,8 +529,13 @@ async def handler(request):
522529
await trio.sleep_forever()
523530
except trio.Cancelled as e:
524531
user_cancelled = e
532+
user_cancelled_cause = e.__cause__
533+
user_cancelled_context = e.__context__
525534
raise
535+
526536
assert exc_info.value is user_cancelled
537+
assert exc_info.value.__cause__ is user_cancelled_cause
538+
assert exc_info.value.__context__ is user_cancelled_context
527539

528540
def _trio_default_non_strict_exception_groups() -> bool:
529541
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
@@ -560,6 +572,24 @@ async def handler(request):
560572
RaisesGroup(ValueError)))).matches(exc.value)
561573

562574

575+
async def test_user_exception_cause(nursery) -> None:
576+
async def handler(request):
577+
await request.accept()
578+
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
579+
e_context = TypeError("foo")
580+
e_primary = ValueError("bar")
581+
e_cause = RuntimeError("zee")
582+
with pytest.raises(ValueError) as exc_info:
583+
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
584+
try:
585+
raise e_context
586+
except TypeError:
587+
raise e_primary from e_cause
588+
e = exc_info.value
589+
assert e is e_primary
590+
assert e.__cause__ is e_cause
591+
assert e.__context__ is e_context
592+
563593
@fail_after(1)
564594
async def test_reject_handshake(nursery):
565595
async def handler(request):
@@ -1176,3 +1206,16 @@ async def server():
11761206
async with trio.open_nursery() as nursery:
11771207
nursery.start_soon(server)
11781208
nursery.start_soon(client)
1209+
1210+
1211+
def test_copy_exceptions():
1212+
# test that exceptions are copy- and pickleable
1213+
copy.copy(HandshakeError())
1214+
copy.copy(ConnectionTimeout())
1215+
copy.copy(DisconnectionTimeout())
1216+
assert copy.copy(ConnectionClosed("foo")).reason == "foo"
1217+
1218+
rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c"))
1219+
assert rej_copy.status_code == 404
1220+
assert rej_copy.headers == (("a", "b"),)
1221+
assert rej_copy.body == b"c"

trio_websocket/_impl.py

+50-18
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import ssl
1212
import struct
1313
import urllib.parse
14-
from typing import Iterable, List, Optional, Union
14+
from typing import Iterable, List, NoReturn, Optional, Union
1515

1616
import outcome
1717
import trio
@@ -151,14 +151,14 @@ async def open_websocket(
151151
# yield to user code. If only one of those raise a non-cancelled exception
152152
# we will raise that non-cancelled exception.
153153
# If we get multiple cancelled, we raise the user's cancelled.
154-
# If both raise exceptions, we raise the user code's exception with the entire
155-
# exception group as the __cause__.
154+
# If both raise exceptions, we raise the user code's exception with __context__
155+
# set to a group containing internal exception(s) + any user exception __context__
156156
# If we somehow get multiple exceptions, but no user exception, then we raise
157157
# TrioWebsocketInternalError.
158158

159159
# If closing the connection fails, then that will be raised as the top
160160
# exception in the last `finally`. If we encountered exceptions in user code
161-
# or in reader task then they will be set as the `__cause__`.
161+
# or in reader task then they will be set as the `__context__`.
162162

163163

164164
async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
@@ -181,10 +181,27 @@ async def _close_connection(connection: WebSocketConnection) -> None:
181181
except trio.TooSlowError:
182182
raise DisconnectionTimeout from None
183183

184+
def _raise(exc: BaseException) -> NoReturn:
185+
"""This helper allows re-raising an exception without __context__ being set."""
186+
# cause does not need special handlng, we simply avoid using `raise .. from ..`
187+
__tracebackhide__ = True
188+
context = exc.__context__
189+
try:
190+
raise exc
191+
finally:
192+
exc.__context__ = context
193+
del exc, context
194+
184195
connection: WebSocketConnection|None=None
185196
close_result: outcome.Maybe[None] | None = None
186197
user_error = None
187198

199+
# Unwrapping exception groups has a lot of pitfalls, one of them stemming from
200+
# the exception we raise also being inside the group that's set as the context.
201+
# This leads to loss of info unless properly handled.
202+
# See https://github.com/python-trio/flake8-async/issues/298
203+
# We therefore avoid having the exceptiongroup included as either cause or context
204+
188205
try:
189206
async with trio.open_nursery() as new_nursery:
190207
result = await outcome.acapture(_open_connection, new_nursery)
@@ -205,7 +222,7 @@ async def _close_connection(connection: WebSocketConnection) -> None:
205222
except _TRIO_EXC_GROUP_TYPE as e:
206223
# user_error, or exception bubbling up from _reader_task
207224
if len(e.exceptions) == 1:
208-
raise e.exceptions[0]
225+
_raise(e.exceptions[0])
209226

210227
# contains at most 1 non-cancelled exceptions
211228
exception_to_raise: BaseException|None = None
@@ -218,25 +235,40 @@ async def _close_connection(connection: WebSocketConnection) -> None:
218235
else:
219236
if exception_to_raise is None:
220237
# all exceptions are cancelled
221-
# prefer raising the one from the user, for traceback reasons
238+
# we reraise the user exception and throw out internal
222239
if user_error is not None:
223-
# no reason to raise from e, just to include a bunch of extra
224-
# cancelleds.
225-
raise user_error # pylint: disable=raise-missing-from
240+
_raise(user_error)
226241
# multiple internal Cancelled is not possible afaik
227-
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
228-
raise exception_to_raise
242+
# but if so we just raise one of them
243+
_raise(e.exceptions[0]) # pragma: no cover
244+
# raise the non-cancelled exception
245+
_raise(exception_to_raise)
229246

230-
# if we have any KeyboardInterrupt in the group, make sure to raise it.
247+
# if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt
248+
# with the group as cause & context
231249
for sub_exc in e.exceptions:
232250
if isinstance(sub_exc, KeyboardInterrupt):
233-
raise sub_exc from e
251+
raise KeyboardInterrupt from e
234252

235253
# Both user code and internal code raised non-cancelled exceptions.
236-
# We "hide" the internal exception(s) in the __cause__ and surface
237-
# the user_error.
254+
# We set the context to be an exception group containing internal exceptions
255+
# and, if not None, `user_error.__context__`
238256
if user_error is not None:
239-
raise user_error from e
257+
exceptions = [subexc for subexc in e.exceptions if subexc is not user_error]
258+
eg_substr = ''
259+
# there's technically loss of info here, with __suppress_context__=True you
260+
# still have original __context__ available, just not printed. But we delete
261+
# it completely because we can't partially suppress the group
262+
if user_error.__context__ is not None and not user_error.__suppress_context__:
263+
exceptions.append(user_error.__context__)
264+
eg_substr = ' and the context for the user exception'
265+
eg_str = (
266+
"Both internal and user exceptions encountered. This group contains "
267+
"the internal exception(s)" + eg_substr + "."
268+
)
269+
user_error.__context__ = BaseExceptionGroup(eg_str, exceptions)
270+
user_error.__suppress_context__ = False
271+
_raise(user_error)
240272

241273
raise TrioWebsocketInternalError(
242274
"The trio-websocket API is not expected to raise multiple exceptions. "
@@ -576,7 +608,7 @@ def __init__(self, reason):
576608
:param reason:
577609
:type reason: CloseReason
578610
'''
579-
super().__init__()
611+
super().__init__(reason)
580612
self.reason = reason
581613

582614
def __repr__(self):
@@ -596,7 +628,7 @@ def __init__(self, status_code, headers, body):
596628
:param reason:
597629
:type reason: CloseReason
598630
'''
599-
super().__init__()
631+
super().__init__(status_code, headers, body)
600632
#: a 3 digit HTTP status code
601633
self.status_code = status_code
602634
#: a tuple of 2-tuples containing header key/value pairs

0 commit comments

Comments
 (0)