|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import collections |
3 | 4 | from collections.abc import Callable
|
4 |
| -from collections.abc import Generator |
| 5 | +import functools |
| 6 | +import sys |
5 | 7 | import threading
|
| 8 | +import time |
6 | 9 | import traceback
|
7 |
| -from types import TracebackType |
8 |
| -from typing import Any |
| 10 | +from typing import NamedTuple |
9 | 11 | from typing import TYPE_CHECKING
|
10 | 12 | import warnings
|
11 | 13 |
|
| 14 | +from _pytest.config import Config |
| 15 | +from _pytest.nodes import Item |
| 16 | +from _pytest.stash import StashKey |
| 17 | +from _pytest.tracemalloc import tracemalloc_message |
12 | 18 | import pytest
|
13 | 19 |
|
14 | 20 |
|
15 | 21 | if TYPE_CHECKING:
|
16 |
| - from typing_extensions import Self |
17 |
| - |
18 |
| - |
19 |
| -# Copied from cpython/Lib/test/support/threading_helper.py, with modifications. |
20 |
| -class catch_threading_exception: |
21 |
| - """Context manager catching threading.Thread exception using |
22 |
| - threading.excepthook. |
23 |
| -
|
24 |
| - Storing exc_value using a custom hook can create a reference cycle. The |
25 |
| - reference cycle is broken explicitly when the context manager exits. |
26 |
| -
|
27 |
| - Storing thread using a custom hook can resurrect it if it is set to an |
28 |
| - object which is being finalized. Exiting the context manager clears the |
29 |
| - stored object. |
30 |
| -
|
31 |
| - Usage: |
32 |
| - with threading_helper.catch_threading_exception() as cm: |
33 |
| - # code spawning a thread which raises an exception |
34 |
| - ... |
35 |
| - # check the thread exception: use cm.args |
36 |
| - ... |
37 |
| - # cm.args attribute no longer exists at this point |
38 |
| - # (to break a reference cycle) |
39 |
| - """ |
40 |
| - |
41 |
| - def __init__(self) -> None: |
42 |
| - self.args: threading.ExceptHookArgs | None = None |
43 |
| - self._old_hook: Callable[[threading.ExceptHookArgs], Any] | None = None |
44 |
| - |
45 |
| - def _hook(self, args: threading.ExceptHookArgs) -> None: |
46 |
| - self.args = args |
47 |
| - |
48 |
| - def __enter__(self) -> Self: |
49 |
| - self._old_hook = threading.excepthook |
50 |
| - threading.excepthook = self._hook |
51 |
| - return self |
52 |
| - |
53 |
| - def __exit__( |
54 |
| - self, |
55 |
| - exc_type: type[BaseException] | None, |
56 |
| - exc_val: BaseException | None, |
57 |
| - exc_tb: TracebackType | None, |
58 |
| - ) -> None: |
59 |
| - assert self._old_hook is not None |
60 |
| - threading.excepthook = self._old_hook |
61 |
| - self._old_hook = None |
62 |
| - del self.args |
63 |
| - |
64 |
| - |
65 |
| -def thread_exception_runtest_hook() -> Generator[None]: |
66 |
| - with catch_threading_exception() as cm: |
| 22 | + pass |
| 23 | + |
| 24 | +if sys.version_info < (3, 11): |
| 25 | + from exceptiongroup import ExceptionGroup |
| 26 | + |
| 27 | + |
| 28 | +def join_threads() -> None: |
| 29 | + start = time.monotonic() |
| 30 | + current_thread = threading.current_thread() |
| 31 | + # This function is executed right at the end of the pytest run, just |
| 32 | + # before we return an exit code, which is where the interpreter joins |
| 33 | + # any remaining non-daemonic threads anyway, so it's ok to join all the |
| 34 | + # threads. However there might be threads that depend on some shutdown |
| 35 | + # signal that happens after pytest finishes, so we want to limit the |
| 36 | + # join time somewhat. A one second timeout seems reasonable. |
| 37 | + timeout = 1 |
| 38 | + for thread in threading.enumerate(): |
| 39 | + if thread is not current_thread and not thread.daemon: |
| 40 | + # TODO: raise an error/warning if there's dangling threads. |
| 41 | + thread.join(timeout - (time.monotonic() - start)) |
| 42 | + |
| 43 | + |
| 44 | +class ThreadExceptionMeta(NamedTuple): |
| 45 | + msg: str |
| 46 | + cause_msg: str |
| 47 | + exc_value: BaseException | None |
| 48 | + |
| 49 | + |
| 50 | +thread_exceptions: StashKey[collections.deque[ThreadExceptionMeta | BaseException]] = ( |
| 51 | + StashKey() |
| 52 | +) |
| 53 | + |
| 54 | + |
| 55 | +def collect_thread_exception(config: Config) -> None: |
| 56 | + pop_thread_exception = config.stash[thread_exceptions].pop |
| 57 | + errors: list[pytest.PytestUnhandledThreadExceptionWarning | RuntimeError] = [] |
| 58 | + meta = None |
| 59 | + hook_error = None |
| 60 | + try: |
| 61 | + while True: |
| 62 | + try: |
| 63 | + meta = pop_thread_exception() |
| 64 | + except IndexError: |
| 65 | + break |
| 66 | + |
| 67 | + if isinstance(meta, BaseException): |
| 68 | + hook_error = RuntimeError("Failed to process thread exception") |
| 69 | + hook_error.__cause__ = meta |
| 70 | + errors.append(hook_error) |
| 71 | + continue |
| 72 | + |
| 73 | + msg = meta.msg |
| 74 | + try: |
| 75 | + warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg)) |
| 76 | + except pytest.PytestUnhandledThreadExceptionWarning as e: |
| 77 | + # This except happens when the warning is treated as an error (e.g. `-Werror`). |
| 78 | + if meta.exc_value is not None: |
| 79 | + # Exceptions have a better way to show the traceback, but |
| 80 | + # warnings do not, so hide the traceback from the msg and |
| 81 | + # set the cause so the traceback shows up in the right place. |
| 82 | + e.args = (meta.cause_msg,) |
| 83 | + e.__cause__ = meta.exc_value |
| 84 | + errors.append(e) |
| 85 | + |
| 86 | + if len(errors) == 1: |
| 87 | + raise errors[0] |
| 88 | + if errors: |
| 89 | + raise ExceptionGroup("multiple thread exception warnings", errors) |
| 90 | + finally: |
| 91 | + del errors, meta, hook_error |
| 92 | + |
| 93 | + |
| 94 | +def cleanup( |
| 95 | + *, config: Config, prev_hook: Callable[[threading.ExceptHookArgs], object] |
| 96 | +) -> None: |
| 97 | + try: |
67 | 98 | try:
|
68 |
| - yield |
| 99 | + join_threads() |
| 100 | + collect_thread_exception(config) |
69 | 101 | finally:
|
70 |
| - if cm.args: |
71 |
| - thread_name = ( |
72 |
| - "<unknown>" if cm.args.thread is None else cm.args.thread.name |
73 |
| - ) |
74 |
| - msg = f"Exception in thread {thread_name}\n\n" |
75 |
| - msg += "".join( |
76 |
| - traceback.format_exception( |
77 |
| - cm.args.exc_type, |
78 |
| - cm.args.exc_value, |
79 |
| - cm.args.exc_traceback, |
80 |
| - ) |
81 |
| - ) |
82 |
| - warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg)) |
83 |
| - |
84 |
| - |
85 |
| -@pytest.hookimpl(wrapper=True, trylast=True) |
86 |
| -def pytest_runtest_setup() -> Generator[None]: |
87 |
| - yield from thread_exception_runtest_hook() |
88 |
| - |
89 |
| - |
90 |
| -@pytest.hookimpl(wrapper=True, tryfirst=True) |
91 |
| -def pytest_runtest_call() -> Generator[None]: |
92 |
| - yield from thread_exception_runtest_hook() |
93 |
| - |
94 |
| - |
95 |
| -@pytest.hookimpl(wrapper=True, tryfirst=True) |
96 |
| -def pytest_runtest_teardown() -> Generator[None]: |
97 |
| - yield from thread_exception_runtest_hook() |
| 102 | + threading.excepthook = prev_hook |
| 103 | + finally: |
| 104 | + del config.stash[thread_exceptions] |
| 105 | + |
| 106 | + |
| 107 | +def thread_exception_hook( |
| 108 | + args: threading.ExceptHookArgs, |
| 109 | + /, |
| 110 | + *, |
| 111 | + append: Callable[[ThreadExceptionMeta | BaseException], object], |
| 112 | +) -> None: |
| 113 | + try: |
| 114 | + # we need to compute these strings here as they might change after |
| 115 | + # the excepthook finishes and before the metadata object is |
| 116 | + # collected by a pytest hook |
| 117 | + thread_name = "<unknown>" if args.thread is None else args.thread.name |
| 118 | + summary = f"Exception in thread {thread_name}" |
| 119 | + traceback_message = "\n\n" + "".join( |
| 120 | + traceback.format_exception( |
| 121 | + args.exc_type, |
| 122 | + args.exc_value, |
| 123 | + args.exc_traceback, |
| 124 | + ) |
| 125 | + ) |
| 126 | + tracemalloc_tb = "\n" + tracemalloc_message(args.thread) |
| 127 | + msg = summary + traceback_message + tracemalloc_tb |
| 128 | + cause_msg = summary + tracemalloc_tb |
| 129 | + |
| 130 | + append( |
| 131 | + ThreadExceptionMeta( |
| 132 | + # Compute these strings here as they might change later |
| 133 | + msg=msg, |
| 134 | + cause_msg=cause_msg, |
| 135 | + exc_value=args.exc_value, |
| 136 | + ) |
| 137 | + ) |
| 138 | + except BaseException as e: |
| 139 | + append(e) |
| 140 | + # Raising this will cause the exception to be logged twice, once in our |
| 141 | + # collect_thread_exception and once by sys.excepthook |
| 142 | + # which is fine - this should never happen anyway and if it does |
| 143 | + # it should probably be reported as a pytest bug. |
| 144 | + raise |
| 145 | + |
| 146 | + |
| 147 | +def pytest_configure(config: Config) -> None: |
| 148 | + prev_hook = threading.excepthook |
| 149 | + deque: collections.deque[ThreadExceptionMeta | BaseException] = collections.deque() |
| 150 | + config.stash[thread_exceptions] = deque |
| 151 | + config.add_cleanup(functools.partial(cleanup, config=config, prev_hook=prev_hook)) |
| 152 | + threading.excepthook = functools.partial(thread_exception_hook, append=deque.append) |
| 153 | + |
| 154 | + |
| 155 | +@pytest.hookimpl(trylast=True) |
| 156 | +def pytest_runtest_setup(item: Item) -> None: |
| 157 | + collect_thread_exception(item.config) |
| 158 | + |
| 159 | + |
| 160 | +@pytest.hookimpl(trylast=True) |
| 161 | +def pytest_runtest_call(item: Item) -> None: |
| 162 | + collect_thread_exception(item.config) |
| 163 | + |
| 164 | + |
| 165 | +@pytest.hookimpl(trylast=True) |
| 166 | +def pytest_runtest_teardown(item: Item) -> None: |
| 167 | + collect_thread_exception(item.config) |
0 commit comments