Skip to content

Commit 7222778

Browse files
authored
Merge pull request Textualize#9 from paul-ollis/xdist
Add support for pytest-xdist for **much faster** Textual tests.
2 parents 380386c + 4a9e1f7 commit 7222778

File tree

1 file changed

+150
-45
lines changed

1 file changed

+150
-45
lines changed

pytest_textual_snapshot.py

+150-45
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

33
import os
4+
import pickle
5+
import re
6+
import shutil
47
from dataclasses import dataclass
58
from datetime import datetime
69
from operator import attrgetter
710
from os import PathLike
811
from pathlib import Path, PurePath
12+
from tempfile import mkdtemp
913
from typing import Awaitable, Union, List, Optional, Callable, Iterable, TYPE_CHECKING
1014

1115
import pytest
@@ -16,14 +20,61 @@
1620
from jinja2 import Template
1721
from rich.console import Console
1822
from syrupy import SnapshotAssertion
23+
from syrupy.extensions.single_file import (
24+
SingleFileSnapshotExtension, WriteMode)
1925

2026
if TYPE_CHECKING:
27+
from _pytest.nodes import Item
2128
from textual.app import App
2229
from textual.pilot import Pilot
2330

24-
TEXTUAL_SNAPSHOT_SVG_KEY = pytest.StashKey[str]()
25-
TEXTUAL_ACTUAL_SVG_KEY = pytest.StashKey[str]()
26-
TEXTUAL_SNAPSHOT_PASS = pytest.StashKey[bool]()
31+
32+
class SVGImageExtension(SingleFileSnapshotExtension):
33+
_file_extension = "svg"
34+
_write_mode = WriteMode.TEXT
35+
36+
37+
class TemporaryDirectory:
38+
"""A temporary that survives forking.
39+
40+
This provides something akin to tempfile.TemporaryDirectory, but this
41+
version is not removed automatically when a process exits.
42+
"""
43+
44+
def __init__(self, name: str = ''):
45+
if name:
46+
self.name = name
47+
else:
48+
self.name = mkdtemp(None, None, None)
49+
50+
def cleanup(self):
51+
"""Clean up the temporary directory."""
52+
shutil.rmtree(self.name, ignore_errors=True)
53+
54+
55+
@dataclass
56+
class PseudoConsole:
57+
"""Something that looks enough like a Console to fill a Jinja2 template."""
58+
59+
legacy_windows: bool
60+
size: ConsoleDimensions
61+
62+
63+
@dataclass
64+
class PseudoApp:
65+
"""Something that looks enough like an App to fill a Jinja2 template.
66+
67+
This can be pickled OK, whereas the 'real' application involved in a test
68+
may contain unpickleable data.
69+
"""
70+
71+
console: PseudoConsole
72+
73+
74+
def rename_styles(svg: str, suffix: str) -> str:
75+
"""Rename style names to prevent clashes when combined in HTML report."""
76+
return re.sub(
77+
r'terminal-(\d+)-r(\d+)', rf'terminal-\1-r\2-{suffix}', svg)
2778

2879

2980
def pytest_addoption(parser):
@@ -39,6 +90,24 @@ def app_stash_key() -> pytest.StashKey:
3990
app_stash_key._key = pytest.StashKey[App]()
4091
return app_stash_key()
4192

93+
94+
def node_to_report_path(node: Item) -> Path:
95+
"""Generate a report file name for a test node."""
96+
tempdir = get_tempdir()
97+
path, _, name = node.reportinfo()
98+
temp = Path(path.parent)
99+
base = []
100+
while temp != temp.parent and temp.name != 'tests':
101+
base.append(temp.name)
102+
temp = temp.parent
103+
parts = []
104+
if base:
105+
parts.append('_'.join(reversed(base)))
106+
parts.append(path.name.replace('.', '_'))
107+
parts.append(name.replace('[', '_').replace(']', '_'))
108+
return Path(tempdir.name) / '_'.join(parts)
109+
110+
42111
@pytest.fixture
43112
def snap_compare(
44113
snapshot: SnapshotAssertion, request: FixtureRequest
@@ -48,6 +117,8 @@ def snap_compare(
48117
app with the output of the same app in the past. This is snapshot testing, and it
49118
used to catch regressions in output.
50119
"""
120+
# Switch so one file per snapshot, stored as plain simple SVG file.
121+
snapshot = snapshot.use_extension(SVGImageExtension)
51122

52123
def compare(
53124
app_path: str | PurePath,
@@ -93,17 +164,18 @@ def compare(
93164
terminal_size=terminal_size,
94165
run_before=run_before,
95166
)
167+
console = Console(legacy_windows=False, force_terminal=True)
168+
p_app = PseudoApp(PseudoConsole(console.legacy_windows, console.size))
169+
96170
result = snapshot == actual_screenshot
171+
expected_svg_text = str(snapshot)
172+
full_path, line_number, name = request.node.reportinfo()
97173

98-
if result is False:
99-
# The split and join below is a mad hack, sorry...
100-
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(
101-
str(snapshot).splitlines()[1:-1]
102-
)
103-
node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot
104-
node.stash[app_stash_key()] = app
105-
else:
106-
node.stash[TEXTUAL_SNAPSHOT_PASS] = True
174+
data = (
175+
result, expected_svg_text, actual_screenshot, p_app, full_path,
176+
line_number, name)
177+
data_path = node_to_report_path(request.node)
178+
data_path.write_bytes(pickle.dumps(data))
107179

108180
return result
109181

@@ -125,37 +197,69 @@ class SvgSnapshotDiff:
125197
environment: dict
126198

127199

200+
def pytest_sessionstart(
201+
session: Session,
202+
) -> None:
203+
"""Set up a temporary directory to store snapshots.
204+
205+
The temporary directory name is stored in an environment vairable so that
206+
pytest-xdist worker child processes can retrieve it.
207+
"""
208+
if os.environ.get('PYTEST_XDIST_WORKER') is None:
209+
tempdir = TemporaryDirectory()
210+
os.environ['TEXTUAL_SNAPSHOT_TEMPDIR'] = tempdir.name
211+
212+
213+
def get_tempdir():
214+
"""Get the TemporaryDirectory."""
215+
return TemporaryDirectory(os.environ['TEXTUAL_SNAPSHOT_TEMPDIR'])
216+
217+
128218
def pytest_sessionfinish(
129219
session: Session,
130220
exitstatus: Union[int, ExitCode],
131221
) -> None:
132222
"""Called after whole test run finished, right before returning the exit status to the system.
133223
Generates the snapshot report and writes it to disk.
134224
"""
135-
diffs: List[SvgSnapshotDiff] = []
136-
num_snapshots_passing = 0
137-
138-
for item in session.items:
139-
# Grab the data our fixture attached to the pytest node
140-
num_snapshots_passing += int(item.stash.get(TEXTUAL_SNAPSHOT_PASS, False))
141-
snapshot_svg = item.stash.get(TEXTUAL_SNAPSHOT_SVG_KEY, None)
142-
actual_svg = item.stash.get(TEXTUAL_ACTUAL_SVG_KEY, None)
143-
app = item.stash.get(app_stash_key(), None)
144-
145-
if app:
146-
path, line_index, name = item.reportinfo()
147-
diffs.append(
148-
SvgSnapshotDiff(
149-
snapshot=str(snapshot_svg),
150-
actual=str(actual_svg),
151-
test_name=name,
152-
path=path,
153-
line_number=line_index + 1,
154-
app=app,
155-
environment=dict(os.environ),
156-
)
157-
)
225+
if os.environ.get('PYTEST_XDIST_WORKER') is None:
226+
tempdir = get_tempdir()
227+
diffs, num_snapshots_passing = retrieve_svg_diffs(tempdir)
228+
save_svg_diffs(diffs, session, num_snapshots_passing)
229+
tempdir.cleanup()
230+
231+
232+
def retrieve_svg_diffs(
233+
tempdir: TemporaryDirectory,
234+
) -> tuple[list[SvgSnapshotDiff], int]:
235+
"""Retrieve snapshot diffs from the temporary directory."""
236+
diffs: list[SvgSnapshotDiff] = []
237+
pass_count = 0
238+
239+
n = 0
240+
for data_path in Path(tempdir.name).iterdir():
241+
(passed, expect_svg_text, svg_text, app, full_path, line_index, name
242+
) = pickle.loads(data_path.read_bytes())
243+
pass_count += 1 if passed else 0
244+
if not passed:
245+
n += 1
246+
diffs.append(SvgSnapshotDiff(
247+
snapshot=rename_styles(str(expect_svg_text), f'exp{n}'),
248+
actual=rename_styles(svg_text, f'act{n}'),
249+
test_name=name,
250+
path=full_path,
251+
line_number=line_index + 1,
252+
app=app,
253+
environment=dict(os.environ)))
254+
return diffs, pass_count
255+
158256

257+
def save_svg_diffs(
258+
diffs: list[SvgSnapshotDiff],
259+
session: Session,
260+
num_snapshots_passing: int,
261+
) -> None:
262+
"""Save any detected differences to an HTML formatted report."""
159263
if diffs:
160264
diff_sort_key = attrgetter("test_name")
161265
diffs = sorted(diffs, key=diff_sort_key)
@@ -198,13 +302,14 @@ def pytest_terminal_summary(
198302
"""Add a section to terminal summary reporting.
199303
Displays the link to the snapshot report that was generated in a prior hook.
200304
"""
201-
diffs = getattr(config, "_textual_snapshots", None)
202-
console = Console(legacy_windows=False, force_terminal=True)
203-
if diffs:
204-
snapshot_report_location = config._textual_snapshot_html_report
205-
console.print("[b red]Textual Snapshot Report", style="red")
206-
console.print(
207-
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
208-
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
209-
)
210-
console.print(f"[dim]{snapshot_report_location}\n")
305+
if os.environ.get('PYTEST_XDIST_WORKER') is None:
306+
diffs = getattr(config, "_textual_snapshots", None)
307+
console = Console(legacy_windows=False, force_terminal=True)
308+
if diffs:
309+
snapshot_report_location = config._textual_snapshot_html_report
310+
console.print("[b red]Textual Snapshot Report", style="red")
311+
console.print(
312+
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
313+
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
314+
)
315+
console.print(f"[dim]{snapshot_report_location}\n")

0 commit comments

Comments
 (0)