Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Update diff_csv to process commented lines #1713

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions tests/junit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from itertools import product
import sys
import time
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Callable

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string # nopep8
Expand Down Expand Up @@ -106,6 +106,15 @@ def cgns_tol(self):
def cgns_tol(self, val):
self._cgns_tol = val

@property
def diff_csv_kwargs(self):
"""Keyword arguments to be passed to diff_csv()"""
return getattr(self, '_diff_csv_kwargs', {})

@diff_csv_kwargs.setter
def diff_csv_kwargs(self, val):
self._diff_csv_kwargs = val

def post_test_hook(self, test: str, spec: TestSpec) -> None:
"""Function callback ran after each test case

Expand Down Expand Up @@ -262,14 +271,17 @@ def get_test_args(source_file: Path) -> List[TestSpec]:
if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])]


def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str:
def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2,
comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
"""Compare CSV results against an expected CSV file with tolerances

Args:
test_csv (Path): Path to output CSV results
true_csv (Path): Path to expected CSV results
zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10.
rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2.
comment_str (str, optional): String to denoting commented line
comment_func (Callable, optional): Function to determine if test and true line are different

Returns:
str: Diff output between result and expected CSVs
Expand All @@ -281,15 +293,34 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f
return f'No lines found in test output {test_csv}'
if len(true_lines) == 0:
return f'No lines found in test source {true_csv}'
if len(test_lines) != len(true_lines):
return f'Number of lines in {test_csv} and {true_csv} do not match'

# Process commented lines
uncommented_lines: List[int] = []
for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)):
if test_line[0] == comment_str and true_line[0] == comment_str:
if comment_func:
output = comment_func(test_line, true_line)
if output:
return output
elif test_line[0] == comment_str and true_line[0] != comment_str:
return f'Commented line found in {test_csv} at line {n} but not in {true_csv}'
elif test_line[0] != comment_str and true_line[0] == comment_str:
return f'Commented line found in {true_csv} at line {n} but not in {test_csv}'
else:
uncommented_lines.append(n)

# Remove commented lines
test_lines = [test_lines[line] for line in uncommented_lines]
true_lines = [true_lines[line] for line in uncommented_lines]

test_reader: csv.DictReader = csv.DictReader(test_lines)
true_reader: csv.DictReader = csv.DictReader(true_lines)
if test_reader.fieldnames != true_reader.fieldnames:
return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
tofile='found CSV columns', fromfile='expected CSV columns'))

if len(test_lines) != len(true_lines):
return f'Number of lines in {test_csv} and {true_csv} do not match'
diff_lines: List[str] = list()
for test_line, true_line in zip(test_reader, true_reader):
for key in test_reader.fieldnames:
Expand Down Expand Up @@ -435,7 +466,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
ref_csvs: List[Path] = []
output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg]
if output_files:
ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files]
ref_cgns: List[Path] = []
output_files = [arg for arg in run_args if 'cgns:' in arg]
if output_files:
Expand Down Expand Up @@ -484,7 +515,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
elif not (Path.cwd() / csv_name).is_file():
test_case.add_failure_info('csv', output=f'{csv_name} not found')
else:
diff: str = diff_csv(Path.cwd() / csv_name, ref_csv)
diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs)
if diff:
test_case.add_failure_info('csv', output=diff)
else:
Expand Down