Skip to content

Commit 21d2989

Browse files
authored
Merge pull request #164 from ConorMacBride/fix-classes
Add support for classes with pytest 7
2 parents 74d9e98 + 19dac01 commit 21d2989

File tree

4 files changed

+162
-131
lines changed

4 files changed

+162
-131
lines changed

.github/workflows/test_and_publish.yml

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ jobs:
3333
- linux: py38-test-mpl33
3434
- linux: py39-test-mpl34
3535
- linux: py310-test-mpl35
36+
# Test different versions of pytest
37+
- linux: py310-test-mpl35-pytestdev
38+
- linux: py310-test-mpl35-pytest62
39+
- linux: py38-test-mpl35-pytest54
3640
coverage: 'codecov'
3741

3842
publish:

pytest_mpl/plugin.py

+108-131
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@
3333
import json
3434
import shutil
3535
import hashlib
36-
import inspect
3736
import logging
3837
import tempfile
3938
import warnings
4039
import contextlib
4140
from pathlib import Path
42-
from functools import wraps
4341
from urllib.request import urlopen
4442

4543
import pytest
@@ -83,6 +81,14 @@ def pathify(path):
8381
return Path(path + ext)
8482

8583

84+
def _pytest_pyfunc_call(obj, pyfuncitem):
85+
testfunction = pyfuncitem.obj
86+
funcargs = pyfuncitem.funcargs
87+
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
88+
obj.result = testfunction(**testargs)
89+
return True
90+
91+
8692
def pytest_report_header(config, startdir):
8793
import matplotlib
8894
import matplotlib.ft2font
@@ -211,13 +217,11 @@ def close_mpl_figure(fig):
211217
plt.close(fig)
212218

213219

214-
def get_marker(item, marker_name):
215-
if hasattr(item, 'get_closest_marker'):
216-
return item.get_closest_marker(marker_name)
217-
else:
218-
# "item.keywords.get" was deprecated in pytest 3.6
219-
# See https://docs.pytest.org/en/latest/mark.html#updating-code
220-
return item.keywords.get(marker_name)
220+
def get_compare(item):
221+
"""
222+
Return the mpl_image_compare marker for the given item.
223+
"""
224+
return item.get_closest_marker("mpl_image_compare")
221225

222226

223227
def path_is_not_none(apath):
@@ -278,20 +282,14 @@ def __init__(self,
278282
logging.basicConfig(level=level)
279283
self.logger = logging.getLogger('pytest-mpl')
280284

281-
def get_compare(self, item):
282-
"""
283-
Return the mpl_image_compare marker for the given item.
284-
"""
285-
return get_marker(item, 'mpl_image_compare')
286-
287285
def generate_filename(self, item):
288286
"""
289287
Given a pytest item, generate the figure filename.
290288
"""
291289
if self.config.getini('mpl-use-full-test-name'):
292290
filename = self.generate_test_name(item) + '.png'
293291
else:
294-
compare = self.get_compare(item)
292+
compare = get_compare(item)
295293
# Find test name to use as plot name
296294
filename = compare.kwargs.get('filename', None)
297295
if filename is None:
@@ -304,7 +302,11 @@ def generate_test_name(self, item):
304302
"""
305303
Generate a unique name for the hash for this test.
306304
"""
307-
return f"{item.module.__name__}.{item.name}"
305+
if item.cls is not None:
306+
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
307+
else:
308+
name = f"{item.module.__name__}.{item.name}"
309+
return name
308310

309311
def make_test_results_dir(self, item):
310312
"""
@@ -319,7 +321,7 @@ def baseline_directory_specified(self, item):
319321
"""
320322
Returns `True` if a non-default baseline directory is specified.
321323
"""
322-
compare = self.get_compare(item)
324+
compare = get_compare(item)
323325
item_baseline_dir = compare.kwargs.get('baseline_dir', None)
324326
return item_baseline_dir or self.baseline_dir or self.baseline_relative_dir
325327

@@ -330,7 +332,7 @@ def get_baseline_directory(self, item):
330332
Using the global and per-test configuration return the absolute
331333
baseline dir, if the baseline file is local else return base URL.
332334
"""
333-
compare = self.get_compare(item)
335+
compare = get_compare(item)
334336
baseline_dir = compare.kwargs.get('baseline_dir', None)
335337
if baseline_dir is None:
336338
if self.baseline_dir is None:
@@ -394,7 +396,7 @@ def generate_baseline_image(self, item, fig):
394396
"""
395397
Generate reference figures.
396398
"""
397-
compare = self.get_compare(item)
399+
compare = get_compare(item)
398400
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
399401

400402
if not os.path.exists(self.generate_dir):
@@ -413,7 +415,7 @@ def generate_image_hash(self, item, fig):
413415
For a `matplotlib.figure.Figure`, returns the SHA256 hash as a hexadecimal
414416
string.
415417
"""
416-
compare = self.get_compare(item)
418+
compare = get_compare(item)
417419
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
418420

419421
imgdata = io.BytesIO()
@@ -436,7 +438,7 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
436438
if summary is None:
437439
summary = {}
438440

439-
compare = self.get_compare(item)
441+
compare = get_compare(item)
440442
tolerance = compare.kwargs.get('tolerance', 2)
441443
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
442444

@@ -510,7 +512,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
510512
if summary is None:
511513
summary = {}
512514

513-
compare = self.get_compare(item)
515+
compare = get_compare(item)
514516
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
515517

516518
if not self.results_hash_library_name:
@@ -582,11 +584,13 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
582584
return
583585
return summary['status_msg']
584586

585-
def pytest_runtest_setup(self, item): # noqa
587+
@pytest.hookimpl(hookwrapper=True)
588+
def pytest_runtest_call(self, item): # noqa
586589

587-
compare = self.get_compare(item)
590+
compare = get_compare(item)
588591

589592
if compare is None:
593+
yield
590594
return
591595

592596
import matplotlib.pyplot as plt
@@ -600,95 +604,82 @@ def pytest_runtest_setup(self, item): # noqa
600604
remove_text = compare.kwargs.get('remove_text', False)
601605
backend = compare.kwargs.get('backend', 'agg')
602606

603-
original = item.function
604-
605-
@wraps(item.function)
606-
def item_function_wrapper(*args, **kwargs):
607-
608-
with plt.style.context(style, after_reset=True), switch_backend(backend):
609-
610-
# Run test and get figure object
611-
if inspect.ismethod(original): # method
612-
# In some cases, for example if setup_method is used,
613-
# original appears to belong to an instance of the test
614-
# class that is not the same as args[0], and args[0] is the
615-
# one that has the correct attributes set up from setup_method
616-
# so we ignore original.__self__ and use args[0] instead.
617-
fig = original.__func__(*args, **kwargs)
618-
else: # function
619-
fig = original(*args, **kwargs)
620-
621-
if remove_text:
622-
remove_ticks_and_titles(fig)
623-
624-
test_name = self.generate_test_name(item)
625-
result_dir = self.make_test_results_dir(item)
626-
627-
summary = {
628-
'status': None,
629-
'image_status': None,
630-
'hash_status': None,
631-
'status_msg': None,
632-
'baseline_image': None,
633-
'diff_image': None,
634-
'rms': None,
635-
'tolerance': None,
636-
'result_image': None,
637-
'baseline_hash': None,
638-
'result_hash': None,
639-
}
640-
641-
# What we do now depends on whether we are generating the
642-
# reference images or simply running the test.
643-
if self.generate_dir is not None:
644-
summary['status'] = 'skipped'
645-
summary['image_status'] = 'generated'
646-
summary['status_msg'] = 'Skipped test, since generating image.'
647-
generate_image = self.generate_baseline_image(item, fig)
648-
if self.results_always: # Make baseline image available in HTML
649-
result_image = (result_dir / "baseline.png").absolute()
650-
shutil.copy(generate_image, result_image)
651-
summary['baseline_image'] = \
652-
result_image.relative_to(self.results_dir).as_posix()
653-
654-
if self.generate_hash_library is not None:
655-
summary['hash_status'] = 'generated'
656-
image_hash = self.generate_image_hash(item, fig)
657-
self._generated_hash_library[test_name] = image_hash
658-
summary['baseline_hash'] = image_hash
659-
660-
# Only test figures if not generating images
661-
if self.generate_dir is None:
662-
# Compare to hash library
663-
if self.hash_library or compare.kwargs.get('hash_library', None):
664-
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
665-
666-
# Compare against a baseline if specified
667-
else:
668-
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
669-
670-
close_mpl_figure(fig)
671-
672-
if msg is None:
673-
if not self.results_always:
674-
shutil.rmtree(result_dir)
675-
for image_type in ['baseline_image', 'diff_image', 'result_image']:
676-
summary[image_type] = None # image no longer exists
677-
else:
678-
self._test_results[test_name] = summary
679-
pytest.fail(msg, pytrace=False)
607+
with plt.style.context(style, after_reset=True), switch_backend(backend):
608+
609+
# Run test and get figure object
610+
yield
611+
fig = self.result
612+
613+
if remove_text:
614+
remove_ticks_and_titles(fig)
615+
616+
test_name = self.generate_test_name(item)
617+
result_dir = self.make_test_results_dir(item)
618+
619+
summary = {
620+
'status': None,
621+
'image_status': None,
622+
'hash_status': None,
623+
'status_msg': None,
624+
'baseline_image': None,
625+
'diff_image': None,
626+
'rms': None,
627+
'tolerance': None,
628+
'result_image': None,
629+
'baseline_hash': None,
630+
'result_hash': None,
631+
}
632+
633+
# What we do now depends on whether we are generating the
634+
# reference images or simply running the test.
635+
if self.generate_dir is not None:
636+
summary['status'] = 'skipped'
637+
summary['image_status'] = 'generated'
638+
summary['status_msg'] = 'Skipped test, since generating image.'
639+
generate_image = self.generate_baseline_image(item, fig)
640+
if self.results_always: # Make baseline image available in HTML
641+
result_image = (result_dir / "baseline.png").absolute()
642+
shutil.copy(generate_image, result_image)
643+
summary['baseline_image'] = \
644+
result_image.relative_to(self.results_dir).as_posix()
645+
646+
if self.generate_hash_library is not None:
647+
summary['hash_status'] = 'generated'
648+
image_hash = self.generate_image_hash(item, fig)
649+
self._generated_hash_library[test_name] = image_hash
650+
summary['baseline_hash'] = image_hash
651+
652+
# Only test figures if not generating images
653+
if self.generate_dir is None:
654+
# Compare to hash library
655+
if self.hash_library or compare.kwargs.get('hash_library', None):
656+
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
657+
658+
# Compare against a baseline if specified
659+
else:
660+
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
680661

681662
close_mpl_figure(fig)
682663

683-
self._test_results[test_name] = summary
664+
if msg is None:
665+
if not self.results_always:
666+
shutil.rmtree(result_dir)
667+
for image_type in ['baseline_image', 'diff_image', 'result_image']:
668+
summary[image_type] = None # image no longer exists
669+
else:
670+
self._test_results[test_name] = summary
671+
pytest.fail(msg, pytrace=False)
684672

685-
if summary['status'] == 'skipped':
686-
pytest.skip(summary['status_msg'])
673+
close_mpl_figure(fig)
687674

688-
if item.cls is not None:
689-
setattr(item.cls, item.function.__name__, item_function_wrapper)
690-
else:
691-
item.obj = item_function_wrapper
675+
self._test_results[test_name] = summary
676+
677+
if summary['status'] == 'skipped':
678+
pytest.skip(summary['status_msg'])
679+
680+
@pytest.hookimpl(tryfirst=True)
681+
def pytest_pyfunc_call(self, pyfuncitem):
682+
return _pytest_pyfunc_call(self, pyfuncitem)
692683

693684
def generate_summary_json(self):
694685
json_file = self.results_dir / 'results.json'
@@ -742,26 +733,12 @@ class FigureCloser:
742733
def __init__(self, config):
743734
self.config = config
744735

745-
def pytest_runtest_setup(self, item):
746-
747-
compare = get_marker(item, 'mpl_image_compare')
748-
749-
if compare is None:
750-
return
751-
752-
original = item.function
753-
754-
@wraps(item.function)
755-
def item_function_wrapper(*args, **kwargs):
756-
757-
if inspect.ismethod(original): # method
758-
fig = original.__func__(*args, **kwargs)
759-
else: # function
760-
fig = original(*args, **kwargs)
761-
762-
close_mpl_figure(fig)
736+
@pytest.hookimpl(hookwrapper=True)
737+
def pytest_runtest_call(self, item):
738+
yield
739+
if get_compare(item) is not None:
740+
close_mpl_figure(self.result)
763741

764-
if item.cls is not None:
765-
setattr(item.cls, item.function.__name__, item_function_wrapper)
766-
else:
767-
item.obj = item_function_wrapper
742+
@pytest.hookimpl(tryfirst=True)
743+
def pytest_pyfunc_call(self, pyfuncitem):
744+
return _pytest_pyfunc_call(self, pyfuncitem)

0 commit comments

Comments
 (0)