Skip to content

Commit 472d92a

Browse files
committed
Add support for comparing multiple baseline images
The failing compare with the lowest rms value will be used for the summary. A shape-mismatch has infinitely low "rms" and will be preferred over any comparison mismatch.
1 parent e387618 commit 472d92a

File tree

1 file changed

+94
-51
lines changed

1 file changed

+94
-51
lines changed

pytest_mpl/plugin.py

+94-51
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import io
3232
import os
33+
import glob
3334
import json
3435
import shutil
3536
import hashlib
@@ -370,40 +371,64 @@ def _download_file(self, baseline, filename):
370371
tmpfile.write(content)
371372
return Path(filename)
372373

373-
def obtain_baseline_image(self, item, target_dir):
374+
def obtain_baseline_images(self, item, target_dir):
374375
"""
375-
Copy the baseline image to our working directory.
376+
Copy the baseline image(s) to our working directory.
376377
377378
If the image is remote it is downloaded, if it is local it is copied to
378379
ensure it is kept in the event of a test failure.
379380
"""
381+
compare = self.get_compare(item)
382+
multi = compare.kwargs.get('multi', False)
380383
filename = self.generate_filename(item)
381384
baseline_dir = self.get_baseline_directory(item)
382385
baseline_remote = (isinstance(baseline_dir, str) and # noqa
383386
baseline_dir.startswith(('http://', 'https://')))
384387
if baseline_remote:
388+
if multi:
389+
pytest.fail('Multi-baseline testing only works with local baselines.',
390+
pytrace=False)
385391
# baseline_dir can be a list of URLs when remote, so we have to
386392
# pass base and filename to download
387-
baseline_image = self._download_file(baseline_dir, filename)
393+
baseline_images = [self._download_file(baseline_dir, filename)]
394+
elif not multi:
395+
baseline_images = [(baseline_dir / filename).absolute()]
388396
else:
389-
baseline_image = (baseline_dir / filename).absolute()
397+
dirname, ext = os.path.splitext(filename)
398+
baseline_images = glob.glob(
399+
os.path.join(baseline_dir.absolute(), dirname, '**', '*' + ext),
400+
recursive=True)
401+
402+
return baseline_images
403+
404+
def obtain_baseline_image(self, item, target_dir):
405+
"""
406+
Backwards-Compatible wrapper for obtain_baseline_images.
390407
391-
return baseline_image
408+
Always returns the first found baseline image.
409+
"""
410+
return self.obtain_baseline_images(item, target_dir)[0]
392411

393412
def generate_baseline_image(self, item, fig):
394413
"""
395414
Generate reference figures.
396415
"""
397416
compare = self.get_compare(item)
398417
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
418+
multi = compare.kwargs.get('multi', False)
399419

400420
if not os.path.exists(self.generate_dir):
401421
os.makedirs(self.generate_dir)
402422

403423
baseline_filename = self.generate_filename(item)
404424
baseline_path = (self.generate_dir / baseline_filename).absolute()
405-
fig.savefig(str(baseline_path), **savefig_kwargs)
425+
if multi:
426+
raw_name, ext = os.path.splitext(str(baseline_path))
427+
if not os.path.exists(raw_name):
428+
os.makedirs(raw_name)
429+
baseline_path = os.path.join(raw_name, "generated" + ext)
406430

431+
fig.savefig(str(baseline_path), **savefig_kwargs)
407432
close_mpl_figure(fig)
408433

409434
return baseline_path
@@ -440,13 +465,14 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
440465
tolerance = compare.kwargs.get('tolerance', 2)
441466
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
442467

443-
baseline_image_ref = self.obtain_baseline_image(item, result_dir)
468+
baseline_image_refs = self.obtain_baseline_images(item, result_dir)
469+
baseline_image_refs = [p for p in baseline_image_refs if os.path.exists(p)]
444470

445471
test_image = (result_dir / "result.png").absolute()
446472
fig.savefig(str(test_image), **savefig_kwargs)
447473
summary['result_image'] = test_image.relative_to(self.results_dir).as_posix()
448474

449-
if not os.path.exists(baseline_image_ref):
475+
if len(baseline_image_refs) == 0:
450476
summary['status'] = 'failed'
451477
summary['image_status'] = 'missing'
452478
error_message = ("Image file not found for comparison test in: \n\t"
@@ -457,49 +483,66 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
457483
summary['status_msg'] = error_message
458484
return error_message
459485

460-
# setuptools may put the baseline images in non-accessible places,
461-
# copy to our tmpdir to be sure to keep them in case of failure
462-
baseline_image = (result_dir / "baseline.png").absolute()
463-
shutil.copyfile(baseline_image_ref, baseline_image)
464-
summary['baseline_image'] = baseline_image.relative_to(self.results_dir).as_posix()
465-
466-
# Compare image size ourselves since the Matplotlib
467-
# exception is a bit cryptic in this case and doesn't show
468-
# the filenames
469-
expected_shape = imread(str(baseline_image)).shape[:2]
470-
actual_shape = imread(str(test_image)).shape[:2]
471-
if expected_shape != actual_shape:
472-
summary['status'] = 'failed'
473-
summary['image_status'] = 'diff'
474-
error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image,
475-
expected_shape=expected_shape,
476-
actual_path=test_image,
477-
actual_shape=actual_shape)
478-
summary['status_msg'] = error_message
479-
return error_message
480-
481-
results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True)
482-
summary['tolerance'] = tolerance
483-
if results is None:
484-
summary['status'] = 'passed'
485-
summary['image_status'] = 'match'
486-
summary['status_msg'] = 'Image comparison passed.'
487-
return None
488-
else:
489-
summary['status'] = 'failed'
490-
summary['image_status'] = 'diff'
491-
summary['rms'] = results['rms']
492-
diff_image = (result_dir / 'result-failed-diff.png').absolute()
493-
summary['diff_image'] = diff_image.relative_to(self.results_dir).as_posix()
494-
template = ['Error: Image files did not match.',
495-
'RMS Value: {rms}',
496-
'Expected: \n {expected}',
497-
'Actual: \n {actual}',
498-
'Difference:\n {diff}',
499-
'Tolerance: \n {tol}', ]
500-
error_message = '\n '.join([line.format(**results) for line in template])
501-
summary['status_msg'] = error_message
502-
return error_message
486+
cur_summ = {}
487+
best_rms = float('inf')
488+
all_msgs = ''
489+
i = -1
490+
491+
for baseline_image_ref in baseline_image_refs:
492+
# setuptools may put the baseline images in non-accessible places,
493+
# copy to our tmpdir to be sure to keep them in case of failure
494+
i += 1
495+
baseline_file = f"baseline-{i}.png" if i else "baseline.png"
496+
baseline_image = (result_dir / baseline_file).absolute()
497+
shutil.copyfile(baseline_image_ref, baseline_image)
498+
cur_summ['baseline_image'] = baseline_image.relative_to(self.results_dir).as_posix()
499+
500+
# Compare image size ourselves since the Matplotlib
501+
# exception is a bit cryptic in this case and doesn't show
502+
# the filenames
503+
expected_shape = imread(str(baseline_image)).shape[:2]
504+
actual_shape = imread(str(test_image)).shape[:2]
505+
if expected_shape != actual_shape:
506+
best_rms = float('-inf')
507+
cur_summ = {}
508+
cur_summ['status'] = 'failed'
509+
cur_summ['image_status'] = 'diff'
510+
error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image,
511+
expected_shape=expected_shape,
512+
actual_path=test_image,
513+
actual_shape=actual_shape)
514+
cur_summ['status_msg'] = error_message
515+
all_msgs += error_message + '\n\n'
516+
continue
517+
518+
results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True)
519+
if results is None:
520+
summary['tolerance'] = tolerance
521+
summary['status'] = 'passed'
522+
summary['image_status'] = 'match'
523+
summary['status_msg'] = 'Image comparison passed.'
524+
return None
525+
else:
526+
template = ['Error: Image files did not match.',
527+
'RMS Value: {rms}',
528+
'Expected: \n {expected}',
529+
'Actual: \n {actual}',
530+
'Difference:\n {diff}',
531+
'Tolerance: \n {tol}', ]
532+
error_message = '\n '.join([line.format(**results) for line in template])
533+
all_msgs += error_message + '\n\n'
534+
if results['rms'] < best_rms:
535+
best_rms = results['rms']
536+
cur_summ = {}
537+
cur_summ['status'] = 'failed'
538+
cur_summ['image_status'] = 'diff'
539+
cur_summ['rms'] = results['rms']
540+
diff_image = (result_dir / 'result-failed-diff.png').absolute()
541+
cur_summ['diff_image'] = diff_image.relative_to(self.results_dir).as_posix()
542+
cur_summ['status_msg'] = error_message
543+
544+
summary.update(cur_summ)
545+
return all_msgs.strip()
503546

504547
def load_hash_library(self, library_path):
505548
with open(str(library_path)) as fp:

0 commit comments

Comments
 (0)