diff --git a/tests/junit_common.py b/tests/junit_common.py index 255f8218e2..94a888440f 100644 --- a/tests/junit_common.py +++ b/tests/junit_common.py @@ -13,7 +13,7 @@ from itertools import product import sys import time -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Callable sys.path.insert(0, str(Path(__file__).parent / "junit-xml")) from junit_xml import TestCase, TestSuite, to_xml_report_string # nopep8 @@ -106,6 +106,15 @@ def cgns_tol(self): def cgns_tol(self, val): self._cgns_tol = val + @property + def diff_csv_kwargs(self): + """Keyword arguments to be passed to diff_csv()""" + return getattr(self, '_diff_csv_kwargs', {}) + + @diff_csv_kwargs.setter + def diff_csv_kwargs(self, val): + self._diff_csv_kwargs = val + def post_test_hook(self, test: str, spec: TestSpec) -> None: """Function callback ran after each test case @@ -262,7 +271,8 @@ def get_test_args(source_file: Path) -> List[TestSpec]: if line.startswith(f'{comment_str}TESTARGS')] or [TestSpec('', args=['{ceed_resource}'])] -def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2) -> str: +def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: float = 1e-2, + comment_str: str = '#', comment_func: Optional[Callable[[str, str], Optional[str]]] = None) -> str: """Compare CSV results against an expected CSV file with tolerances Args: @@ -270,6 +280,8 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f true_csv (Path): Path to expected CSV results zero_tol (float, optional): Tolerance below which values are considered to be zero. Defaults to 3e-10. rel_tol (float, optional): Relative tolerance for comparing non-zero values. Defaults to 1e-2. + comment_str (str, optional): String to denoting commented line + comment_func (Callable, optional): Function to determine if test and true line are different Returns: str: Diff output between result and expected CSVs @@ -281,6 +293,27 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f return f'No lines found in test output {test_csv}' if len(true_lines) == 0: return f'No lines found in test source {true_csv}' + if len(test_lines) != len(true_lines): + return f'Number of lines in {test_csv} and {true_csv} do not match' + + # Process commented lines + uncommented_lines: List[int] = [] + for n, (test_line, true_line) in enumerate(zip(test_lines, true_lines)): + if test_line[0] == comment_str and true_line[0] == comment_str: + if comment_func: + output = comment_func(test_line, true_line) + if output: + return output + elif test_line[0] == comment_str and true_line[0] != comment_str: + return f'Commented line found in {test_csv} at line {n} but not in {true_csv}' + elif test_line[0] != comment_str and true_line[0] == comment_str: + return f'Commented line found in {true_csv} at line {n} but not in {test_csv}' + else: + uncommented_lines.append(n) + + # Remove commented lines + test_lines = [test_lines[line] for line in uncommented_lines] + true_lines = [true_lines[line] for line in uncommented_lines] test_reader: csv.DictReader = csv.DictReader(test_lines) true_reader: csv.DictReader = csv.DictReader(true_lines) @@ -288,8 +321,6 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'], tofile='found CSV columns', fromfile='expected CSV columns')) - if len(test_lines) != len(true_lines): - return f'Number of lines in {test_csv} and {true_csv} do not match' diff_lines: List[str] = list() for test_line, true_line in zip(test_reader, true_reader): for key in test_reader.fieldnames: @@ -435,7 +466,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str, ref_csvs: List[Path] = [] output_files: List[str] = [arg for arg in run_args if 'ascii:' in arg] if output_files: - ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files] + ref_csvs = [suite_spec.get_output_path(test, file.split(':')[1]) for file in output_files] ref_cgns: List[Path] = [] output_files = [arg for arg in run_args if 'cgns:' in arg] if output_files: @@ -484,7 +515,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str, elif not (Path.cwd() / csv_name).is_file(): test_case.add_failure_info('csv', output=f'{csv_name} not found') else: - diff: str = diff_csv(Path.cwd() / csv_name, ref_csv) + diff: str = diff_csv(Path.cwd() / csv_name, ref_csv, **suite_spec.diff_csv_kwargs) if diff: test_case.add_failure_info('csv', output=diff) else: