1
1
from __future__ import annotations
2
2
3
3
import os
4
+ import pickle
5
+ import re
6
+ import shutil
4
7
from dataclasses import dataclass
5
8
from datetime import datetime
6
9
from operator import attrgetter
7
10
from os import PathLike
8
11
from pathlib import Path , PurePath
12
+ from tempfile import mkdtemp
9
13
from typing import Awaitable , Union , List , Optional , Callable , Iterable , TYPE_CHECKING
10
14
11
15
import pytest
16
20
from jinja2 import Template
17
21
from rich .console import Console
18
22
from syrupy import SnapshotAssertion
23
+ from syrupy .extensions .single_file import (
24
+ SingleFileSnapshotExtension , WriteMode )
19
25
20
26
if TYPE_CHECKING :
27
+ from _pytest .nodes import Item
21
28
from textual .app import App
22
29
from textual .pilot import Pilot
23
30
24
- TEXTUAL_SNAPSHOT_SVG_KEY = pytest .StashKey [str ]()
25
- TEXTUAL_ACTUAL_SVG_KEY = pytest .StashKey [str ]()
26
- TEXTUAL_SNAPSHOT_PASS = pytest .StashKey [bool ]()
31
+
32
+ class SVGImageExtension (SingleFileSnapshotExtension ):
33
+ _file_extension = "svg"
34
+ _write_mode = WriteMode .TEXT
35
+
36
+
37
+ class TemporaryDirectory :
38
+ """A temporary that survives forking.
39
+
40
+ This provides something akin to tempfile.TemporaryDirectory, but this
41
+ version is not removed automatically when a process exits.
42
+ """
43
+
44
+ def __init__ (self , name : str = '' ):
45
+ if name :
46
+ self .name = name
47
+ else :
48
+ self .name = mkdtemp (None , None , None )
49
+
50
+ def cleanup (self ):
51
+ """Clean up the temporary directory."""
52
+ shutil .rmtree (self .name , ignore_errors = True )
53
+
54
+
55
+ @dataclass
56
+ class PseudoConsole :
57
+ """Something that looks enough like a Console to fill a Jinja2 template."""
58
+
59
+ legacy_windows : bool
60
+ size : ConsoleDimensions
61
+
62
+
63
+ @dataclass
64
+ class PseudoApp :
65
+ """Something that looks enough like an App to fill a Jinja2 template.
66
+
67
+ This can be pickled OK, whereas the 'real' application involved in a test
68
+ may contain unpickleable data.
69
+ """
70
+
71
+ console : PseudoConsole
72
+
73
+
74
+ def rename_styles (svg : str , suffix : str ) -> str :
75
+ """Rename style names to prevent clashes when combined in HTML report."""
76
+ return re .sub (
77
+ r'terminal-(\d+)-r(\d+)' , rf'terminal-\1-r\2-{ suffix } ' , svg )
27
78
28
79
29
80
def pytest_addoption (parser ):
@@ -39,6 +90,24 @@ def app_stash_key() -> pytest.StashKey:
39
90
app_stash_key ._key = pytest .StashKey [App ]()
40
91
return app_stash_key ()
41
92
93
+
94
+ def node_to_report_path (node : Item ) -> Path :
95
+ """Generate a report file name for a test node."""
96
+ tempdir = get_tempdir ()
97
+ path , _ , name = node .reportinfo ()
98
+ temp = Path (path .parent )
99
+ base = []
100
+ while temp != temp .parent and temp .name != 'tests' :
101
+ base .append (temp .name )
102
+ temp = temp .parent
103
+ parts = []
104
+ if base :
105
+ parts .append ('_' .join (reversed (base )))
106
+ parts .append (path .name .replace ('.' , '_' ))
107
+ parts .append (name .replace ('[' , '_' ).replace (']' , '_' ))
108
+ return Path (tempdir .name ) / '_' .join (parts )
109
+
110
+
42
111
@pytest .fixture
43
112
def snap_compare (
44
113
snapshot : SnapshotAssertion , request : FixtureRequest
@@ -48,6 +117,8 @@ def snap_compare(
48
117
app with the output of the same app in the past. This is snapshot testing, and it
49
118
used to catch regressions in output.
50
119
"""
120
+ # Switch so one file per snapshot, stored as plain simple SVG file.
121
+ snapshot = snapshot .use_extension (SVGImageExtension )
51
122
52
123
def compare (
53
124
app_path : str | PurePath ,
@@ -93,17 +164,18 @@ def compare(
93
164
terminal_size = terminal_size ,
94
165
run_before = run_before ,
95
166
)
167
+ console = Console (legacy_windows = False , force_terminal = True )
168
+ p_app = PseudoApp (PseudoConsole (console .legacy_windows , console .size ))
169
+
96
170
result = snapshot == actual_screenshot
171
+ expected_svg_text = str (snapshot )
172
+ full_path , line_number , name = request .node .reportinfo ()
97
173
98
- if result is False :
99
- # The split and join below is a mad hack, sorry...
100
- node .stash [TEXTUAL_SNAPSHOT_SVG_KEY ] = "\n " .join (
101
- str (snapshot ).splitlines ()[1 :- 1 ]
102
- )
103
- node .stash [TEXTUAL_ACTUAL_SVG_KEY ] = actual_screenshot
104
- node .stash [app_stash_key ()] = app
105
- else :
106
- node .stash [TEXTUAL_SNAPSHOT_PASS ] = True
174
+ data = (
175
+ result , expected_svg_text , actual_screenshot , p_app , full_path ,
176
+ line_number , name )
177
+ data_path = node_to_report_path (request .node )
178
+ data_path .write_bytes (pickle .dumps (data ))
107
179
108
180
return result
109
181
@@ -125,37 +197,69 @@ class SvgSnapshotDiff:
125
197
environment : dict
126
198
127
199
200
+ def pytest_sessionstart (
201
+ session : Session ,
202
+ ) -> None :
203
+ """Set up a temporary directory to store snapshots.
204
+
205
+ The temporary directory name is stored in an environment vairable so that
206
+ pytest-xdist worker child processes can retrieve it.
207
+ """
208
+ if os .environ .get ('PYTEST_XDIST_WORKER' ) is None :
209
+ tempdir = TemporaryDirectory ()
210
+ os .environ ['TEXTUAL_SNAPSHOT_TEMPDIR' ] = tempdir .name
211
+
212
+
213
+ def get_tempdir ():
214
+ """Get the TemporaryDirectory."""
215
+ return TemporaryDirectory (os .environ ['TEXTUAL_SNAPSHOT_TEMPDIR' ])
216
+
217
+
128
218
def pytest_sessionfinish (
129
219
session : Session ,
130
220
exitstatus : Union [int , ExitCode ],
131
221
) -> None :
132
222
"""Called after whole test run finished, right before returning the exit status to the system.
133
223
Generates the snapshot report and writes it to disk.
134
224
"""
135
- diffs : List [SvgSnapshotDiff ] = []
136
- num_snapshots_passing = 0
137
-
138
- for item in session .items :
139
- # Grab the data our fixture attached to the pytest node
140
- num_snapshots_passing += int (item .stash .get (TEXTUAL_SNAPSHOT_PASS , False ))
141
- snapshot_svg = item .stash .get (TEXTUAL_SNAPSHOT_SVG_KEY , None )
142
- actual_svg = item .stash .get (TEXTUAL_ACTUAL_SVG_KEY , None )
143
- app = item .stash .get (app_stash_key (), None )
144
-
145
- if app :
146
- path , line_index , name = item .reportinfo ()
147
- diffs .append (
148
- SvgSnapshotDiff (
149
- snapshot = str (snapshot_svg ),
150
- actual = str (actual_svg ),
151
- test_name = name ,
152
- path = path ,
153
- line_number = line_index + 1 ,
154
- app = app ,
155
- environment = dict (os .environ ),
156
- )
157
- )
225
+ if os .environ .get ('PYTEST_XDIST_WORKER' ) is None :
226
+ tempdir = get_tempdir ()
227
+ diffs , num_snapshots_passing = retrieve_svg_diffs (tempdir )
228
+ save_svg_diffs (diffs , session , num_snapshots_passing )
229
+ tempdir .cleanup ()
230
+
231
+
232
+ def retrieve_svg_diffs (
233
+ tempdir : TemporaryDirectory ,
234
+ ) -> tuple [list [SvgSnapshotDiff ], int ]:
235
+ """Retrieve snapshot diffs from the temporary directory."""
236
+ diffs : list [SvgSnapshotDiff ] = []
237
+ pass_count = 0
238
+
239
+ n = 0
240
+ for data_path in Path (tempdir .name ).iterdir ():
241
+ (passed , expect_svg_text , svg_text , app , full_path , line_index , name
242
+ ) = pickle .loads (data_path .read_bytes ())
243
+ pass_count += 1 if passed else 0
244
+ if not passed :
245
+ n += 1
246
+ diffs .append (SvgSnapshotDiff (
247
+ snapshot = rename_styles (str (expect_svg_text ), f'exp{ n } ' ),
248
+ actual = rename_styles (svg_text , f'act{ n } ' ),
249
+ test_name = name ,
250
+ path = full_path ,
251
+ line_number = line_index + 1 ,
252
+ app = app ,
253
+ environment = dict (os .environ )))
254
+ return diffs , pass_count
255
+
158
256
257
+ def save_svg_diffs (
258
+ diffs : list [SvgSnapshotDiff ],
259
+ session : Session ,
260
+ num_snapshots_passing : int ,
261
+ ) -> None :
262
+ """Save any detected differences to an HTML formatted report."""
159
263
if diffs :
160
264
diff_sort_key = attrgetter ("test_name" )
161
265
diffs = sorted (diffs , key = diff_sort_key )
@@ -198,13 +302,14 @@ def pytest_terminal_summary(
198
302
"""Add a section to terminal summary reporting.
199
303
Displays the link to the snapshot report that was generated in a prior hook.
200
304
"""
201
- diffs = getattr (config , "_textual_snapshots" , None )
202
- console = Console (legacy_windows = False , force_terminal = True )
203
- if diffs :
204
- snapshot_report_location = config ._textual_snapshot_html_report
205
- console .print ("[b red]Textual Snapshot Report" , style = "red" )
206
- console .print (
207
- f"\n [black on red]{ len (diffs )} mismatched snapshots[/]\n "
208
- f"\n [b]View the [link=file://{ snapshot_report_location } ]failure report[/].\n "
209
- )
210
- console .print (f"[dim]{ snapshot_report_location } \n " )
305
+ if os .environ .get ('PYTEST_XDIST_WORKER' ) is None :
306
+ diffs = getattr (config , "_textual_snapshots" , None )
307
+ console = Console (legacy_windows = False , force_terminal = True )
308
+ if diffs :
309
+ snapshot_report_location = config ._textual_snapshot_html_report
310
+ console .print ("[b red]Textual Snapshot Report" , style = "red" )
311
+ console .print (
312
+ f"\n [black on red]{ len (diffs )} mismatched snapshots[/]\n "
313
+ f"\n [b]View the [link=file://{ snapshot_report_location } ]failure report[/].\n "
314
+ )
315
+ console .print (f"[dim]{ snapshot_report_location } \n " )
0 commit comments