|
1 |
| -import os |
2 |
| -import tempfile |
| 1 | +import shutil |
| 2 | +from os.path import exists, join |
3 | 3 | from pathlib import Path
|
4 |
| -from shutil import rmtree |
5 | 4 |
|
6 | 5 | import pytest
|
7 | 6 |
|
8 | 7 | from diffpy.nmf_mapping.main import main
|
9 | 8 |
|
10 | 9 | dir = Path(__file__).parent.resolve()
|
11 |
| - |
12 |
| -data_dir = os.path.join(dir, "data/synthetic_r_vs_gr") |
13 |
| - |
14 |
| -test_map = [ |
15 |
| - ([data_dir, "--xrange", "5,10"], "output_1", "Number of components: 3\n"), |
16 |
| - ([data_dir], "output_2", "Number of components: 3\n"), |
17 |
| - ([data_dir, "--xrange", "5,10", "12,15"], "output_3", "Number of components: 3\n"), |
18 |
| -] |
19 |
| - |
20 |
| - |
21 |
| -@pytest.fixture(scope="session") |
22 |
| -def temp_dir(): |
23 |
| - """A test fixture that creates and destroys tes outputs in a temporary |
24 |
| - directory. |
25 |
| - This will yield the path to the directory. |
26 |
| - """ |
27 |
| - cwd = os.getcwd() |
28 |
| - name = "outputs" |
29 |
| - temp_dir = Path(tempfile.gettempdir()) |
30 |
| - repo = os.path.join(temp_dir, name) |
31 |
| - if os.path.exists(repo): |
32 |
| - rmtree(repo) |
33 |
| - os.chdir(temp_dir) |
34 |
| - os.mkdir(name) |
35 |
| - os.chdir(cwd) |
36 |
| - yield repo |
37 |
| - os.chdir(cwd) |
38 |
| - rmtree(repo) |
39 |
| - |
40 |
| - |
41 |
| -@pytest.mark.parametrize("tm", test_map) |
42 |
| -def test_nmf_mapping_code(tm, temp_dir, capsys): |
43 |
| - data_dir = tm[0] |
44 |
| - working_dir = Path(temp_dir) |
45 |
| - os.chdir(working_dir) |
46 |
| - main(args=data_dir) |
47 |
| - out, err = capsys.readouterr() |
48 |
| - assert out == tm[2] |
49 |
| - results_dir = os.path.join(working_dir, "nmf_result") |
50 |
| - os.chdir(results_dir) |
51 |
| - expected_base = os.path.join(os.path.dirname(__file__), "output") |
52 |
| - test_specific_dir = os.path.join(expected_base, tm[1]) |
53 |
| - for root, dirs, files in os.walk("."): |
54 |
| - for file in files: |
55 |
| - if file in os.listdir(test_specific_dir): |
56 |
| - fn1 = os.path.join(results_dir, file) |
57 |
| - with open(fn1, "r") as f: |
58 |
| - actual = f.read() |
59 |
| - fn2 = os.path.join(test_specific_dir, file) |
60 |
| - with open(fn2, "r") as f: |
61 |
| - expected = f.read() |
62 |
| - assert expected == actual |
| 10 | +data_dir = join(dir, "data/synthetic_r_vs_gr") |
| 11 | + |
| 12 | + |
| 13 | +@pytest.mark.parametrize( |
| 14 | + "args, output_dir", |
| 15 | + [ |
| 16 | + (["tests/data/synthetic_r_vs_gr", "--xrange", "5,10", "--show", "false"], "output_1"), |
| 17 | + (["tests/data/synthetic_r_vs_gr", "--show", "false"], "output_2"), |
| 18 | + (["tests/data/synthetic_r_vs_gr", "--xrange", "5,10", "12,15", "--show", "false"], "output_3"), |
| 19 | + ], |
| 20 | +) |
| 21 | +def test_nmf_mapping_code(args, output_dir, tmpdir): |
| 22 | + |
| 23 | + # Save the result in ("nmf_result") at the top project level (default behavior) |
| 24 | + main(args=args) |
| 25 | + |
| 26 | + # Define the copied results directory in tmpdir |
| 27 | + tmp_results_dir = join(tmpdir, "nmf_result") |
| 28 | + expected_output_dir = join("tests/output", output_dir) |
| 29 | + |
| 30 | + # Copy the output to tmpdir |
| 31 | + shutil.copytree("nmf_result", tmp_results_dir) |
| 32 | + |
| 33 | + # Remove the nmf_result folder from the top project level |
| 34 | + shutil.rmtree("nmf_result") |
| 35 | + |
| 36 | + # Define the specific JSON files to check |
| 37 | + json_files_to_check = [ |
| 38 | + "component_index_vs_pratio_col.json", |
| 39 | + "component_index_vs_RE_value.json", |
| 40 | + "x_index_vs_y_col_components.json", |
| 41 | + ] |
| 42 | + |
| 43 | + # Compare each specified .json file |
| 44 | + for json_file in json_files_to_check: |
| 45 | + # Define paths for actual and expected .json files |
| 46 | + actual_file_path = join(tmp_results_dir, json_file) |
| 47 | + expected_file_path = join(expected_output_dir, json_file) |
| 48 | + |
| 49 | + # Ensure the file exists in both locations |
| 50 | + assert exists(actual_file_path) |
| 51 | + assert exists(expected_file_path) |
| 52 | + |
| 53 | + # Read and compare file contents |
| 54 | + with open(actual_file_path, "r") as actual_file: |
| 55 | + actual = actual_file.read() |
| 56 | + |
| 57 | + with open(expected_file_path, "r") as expected_file: |
| 58 | + expected = expected_file.read() |
| 59 | + |
| 60 | + # Assert that the contents match |
| 61 | + assert actual == expected |
0 commit comments