Skip to content

Commit 5defd03

Browse files
committed
no-copy solution that completely hides the exceptiongroup in most cases
1 parent a4e0562 commit 5defd03

File tree

2 files changed

+67
-44
lines changed

2 files changed

+67
-44
lines changed

tests/test_connection.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
452452
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
453453
"""
454454
async def ki_raising_ping_handler(*args, **kwargs) -> None:
455-
print("raising ki")
456455
raise KeyboardInterrupt
457456
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
458457
async def handler(request):
@@ -474,27 +473,32 @@ async def handler(request):
474473
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
475474
"""_reader_task._handle_ping_event triggers ValueError.
476475
user code also raises exception.
477-
internal exception is in __cause__ exceptiongroup and user exc is delivered
476+
internal exception is in __context__ exceptiongroup and user exc is delivered
478477
"""
479-
my_value_error = ValueError()
478+
internal_error = ValueError()
479+
internal_error.__context__ = TypeError()
480+
user_error = NameError()
481+
user_error_context = KeyError()
480482
async def raising_ping_event(*args, **kwargs) -> None:
481-
raise my_value_error
483+
raise internal_error
482484

483485
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
484486
async def handler(request):
485487
server_ws = await request.accept()
486488
await server_ws.ping(b"a")
487489

488490
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
489-
with pytest.raises(trio.TooSlowError) as exc_info:
491+
with pytest.raises(type(user_error)) as exc_info:
490492
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
491-
with trio.fail_after(1) as cs:
492-
cs.shield = True
493-
await trio.sleep(2)
493+
await trio.lowlevel.checkpoint()
494+
user_error.__context__ = user_error_context
495+
raise user_error
494496

495-
e_cause = exc_info.value.__cause__
496-
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
497-
assert my_value_error in e_cause.exceptions
497+
assert exc_info.value is user_error
498+
e_context = exc_info.value.__context__
499+
assert isinstance(e_context, BaseExceptionGroup)
500+
assert internal_error in e_context.exceptions
501+
assert user_error_context in e_context.exceptions
498502

499503
@fail_after(5)
500504
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
@@ -528,14 +532,9 @@ async def handler(request):
528532
user_cancelled_context = e.__context__
529533
raise
530534

531-
# a copy of user_cancelled is reraised
532-
assert exc_info.value is not user_cancelled
533-
# with the same cause
535+
assert exc_info.value is user_cancelled
534536
assert exc_info.value.__cause__ is user_cancelled_cause
535-
# the context is the exception group, which contains the original user_cancelled
536-
assert exc_info.value.__context__.exceptions[1] is user_cancelled
537-
assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause
538-
assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context
537+
assert exc_info.value.__context__ is user_cancelled_context
539538

540539
def _trio_default_non_strict_exception_groups() -> bool:
541540
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
@@ -586,19 +585,9 @@ async def handler(request):
586585
except TypeError:
587586
raise e_primary from e_cause
588587
e = exc_info.value
589-
if _trio_default_non_strict_exception_groups():
590-
assert e is e_primary
591-
assert e.__cause__ is e_cause
592-
assert e.__context__ is e_context
593-
else:
594-
# a copy is reraised to avoid losing e_context
595-
assert e is not e_primary
596-
assert e.__cause__ is e_cause
597-
598-
# the nursery-internal group is injected as context
599-
assert isinstance(e.__context__, _TRIO_EXC_GROUP_TYPE)
600-
assert e.__context__.exceptions[0] is e_primary
601-
assert e.__context__.exceptions[0].__context__ is e_context
588+
assert e is e_primary
589+
assert e.__cause__ is e_cause
590+
assert e.__context__ is e_context
602591

603592
@fail_after(1)
604593
async def test_reject_handshake(nursery):

trio_websocket/_impl.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import ssl
1313
import struct
1414
import urllib.parse
15-
from typing import Iterable, List, Optional, Union
15+
from typing import Iterable, List, NoReturn, Optional, Union
1616

1717
import outcome
1818
import trio
@@ -192,10 +192,29 @@ async def _close_connection(connection: WebSocketConnection) -> None:
192192
except trio.TooSlowError:
193193
raise DisconnectionTimeout from None
194194

195+
def _raise(exc: BaseException) -> NoReturn:
196+
__tracebackhide__ = True
197+
context = exc.__context__
198+
try:
199+
raise exc
200+
finally:
201+
exc.__context__ = context
202+
del exc, context
203+
195204
connection: WebSocketConnection|None=None
196205
close_result: outcome.Maybe[None] | None = None
197206
user_error = None
198207

208+
# Unwrapping exception groups has a lot of pitfalls, one of them stemming from
209+
# the exception we raise also being inside the group that's set as the context.
210+
# This leads to loss of info unless properly handled.
211+
# See https://github.com/python-trio/flake8-async/issues/298
212+
# We therefore save the exception before raising it, and save our intended context,
213+
# so they can be modified in the `finally`.
214+
exc_to_raise = None
215+
exc_context = None
216+
# by avoiding use of `raise .. from ..` we leave the original __cause__
217+
199218
try:
200219
async with trio.open_nursery() as new_nursery:
201220
result = await outcome.acapture(_open_connection, new_nursery)
@@ -216,7 +235,7 @@ async def _close_connection(connection: WebSocketConnection) -> None:
216235
except _TRIO_EXC_GROUP_TYPE as e:
217236
# user_error, or exception bubbling up from _reader_task
218237
if len(e.exceptions) == 1:
219-
raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__
238+
_raise(e.exceptions[0])
220239

221240
# contains at most 1 non-cancelled exceptions
222241
exception_to_raise: BaseException|None = None
@@ -229,25 +248,40 @@ async def _close_connection(connection: WebSocketConnection) -> None:
229248
else:
230249
if exception_to_raise is None:
231250
# all exceptions are cancelled
232-
# prefer raising the one from the user, for traceback reasons
251+
# we reraise the user exception and throw out internal
233252
if user_error is not None:
234-
# no reason to raise from e, just to include a bunch of extra
235-
# cancelleds.
236-
raise copy_exc(user_error) from user_error.__cause__
253+
_raise(user_error)
237254
# multiple internal Cancelled is not possible afaik
238-
raise copy_exc(e.exceptions[0]) from e # pragma: no cover
239-
raise copy_exc(exception_to_raise) from exception_to_raise.__cause__
255+
# but if so we just raise one of them
256+
_raise(e.exceptions[0])
257+
# raise the non-cancelled exception
258+
_raise(exception_to_raise)
240259

241-
# if we have any KeyboardInterrupt in the group, make sure to raise it.
260+
# if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt
261+
# with the group as cause & context
242262
for sub_exc in e.exceptions:
243263
if isinstance(sub_exc, KeyboardInterrupt):
244-
raise copy_exc(sub_exc) from e
264+
raise KeyboardInterrupt from e
245265

246266
# Both user code and internal code raised non-cancelled exceptions.
247-
# We "hide" the internal exception(s) in the __cause__ and surface
248-
# the user_error.
267+
# We set the context to be an exception group containing internal exceptions
268+
# and, if not None, `user_error.__context__`
249269
if user_error is not None:
250-
raise copy_exc(user_error) from e
270+
exceptions = [subexc for subexc in e.exceptions if subexc is not user_error]
271+
eg_substr = ''
272+
# there's technically loss of info here, with __suppress_context__=True you
273+
# still have original __context__ available, just not printed. But we delete
274+
# it completely because we can't partially suppress the group
275+
if user_error.__context__ is not None and not user_error.__suppress_context__:
276+
exceptions.append(user_error.__context__)
277+
eg_substr = ' and the context for the user exception'
278+
eg_str = (
279+
"Both internal and user exceptions encountered. This group contains "
280+
"the internal exception(s)" + eg_substr + "."
281+
)
282+
user_error.__context__ = BaseExceptionGroup(eg_str, exceptions)
283+
user_error.__suppress_context__ = False
284+
_raise(user_error)
251285

252286
raise TrioWebsocketInternalError(
253287
"The trio-websocket API is not expected to raise multiple exceptions. "

0 commit comments

Comments
 (0)