Skip to content

Commit

Permalink
test: Update diff_csv to process commented lines
Browse files Browse the repository at this point in the history
  • Loading branch information
jrwrigh committed Dec 2, 2024
1 parent be8d6f5 commit 12235d7
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions tests/junit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from itertools import product
import sys
import time
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Callable

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string # nopep8
Expand Down Expand Up @@ -106,6 +106,15 @@ def cgns_tol(self):
def cgns_tol(self, val):
self._cgns_tol = val

@property
def diff_csv_kwargs(self):
"""Keyword arguments to be passed to diff_csv()"""
return getattr(self, '_diff_csv_kwargs', {})

@diff_csv_kwargs.setter
def diff_csv_kwargs(self, val):
self._diff_csv_kwargs = val

def post_test_hook(self, test: str, spec: TestSpec) -> None:
"""Function callback ran after each test case
Expand Down Expand Up @@ -262,14 +271,17 @@ def get_test_args(source_file: Path) -> List[TestSpec]:
if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2,
comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
"""Compare CSV results against an expected CSV file with tolerances
Args:
test_csv (Path): Path to output CSV results
true_csv (Path): Path to expected CSV results
zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.
comment_str (str, optional): String to denoting commented line
comment_func (Callable, optional): Function to determine if test and true line are different
Returns:
str: Diff output between result and expected CSVs
Expand All @@ -281,15 +293,34 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f
return f'No lines found in test output {test_csv}'
if len(true_lines) == 0:
return f'No lines found in test source {true_csv}'
if len(test_lines) != len(true_lines):
return f'Number of lines in {test_csv} and {true_csv} do not match'

# Process commented lines
uncommented_lines: List[int] = []
for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
if test_line[0] == comment_str and true_line[0] == comment_str:
if comment_func:
output = comment_func(test_line, true_line)
if output:
return output
elif test_line[0] == comment_str and true_line[0] != comment_str:
return f'Commented line found in {test_csv} at line {n} but not in {true_csv}'
elif test_line[0] != comment_str and true_line[0] == comment_str:
return f'Commented line found in {true_csv} at line {n} but not in {test_csv}'
else:
uncommented_lines.append(n)

# Remove commented lines
test_lines = [test_lines[line] for line in uncommented_lines]
true_lines = [true_lines[line] for line in uncommented_lines]

test_reader: csv.DictReader = csv.DictReader(test_lines)
true_reader: csv.DictReader = csv.DictReader(true_lines)
if test_reader.fieldnames != true_reader.fieldnames:
return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
tofile='found CSV columns', fromfile='expected CSV columns'))

if len(test_lines) != len(true_lines):
return f'Number of lines in {test_csv} and {true_csv} do not match'
diff_lines: List[str] = list()
for test_line, true_line in zip(test_reader, true_reader):
for key in test_reader.fieldnames:
Expand Down Expand Up @@ -435,7 +466,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
ref_csvs: List[Path] = []
output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
if output_files:
ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files]
ref_cgns: List[Path] = []
output_files = [arg for arg in run_args if 'cgns:' in arg]
if output_files:
Expand Down Expand Up @@ -484,7 +515,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
elif not (Path.cwd() / csv_name).is_file():
test_case.add_failure_info('csv', output=f'{csv_name} not found')
else:
diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs)
if diff:
test_case.add_failure_info('csv', output=diff)
else:
Expand Down

0 comments on commit 12235d7

Please sign in to comment.