Skip to content

Commit 10c8e14

Browse files
committed
refactor pytest
1 parent 7e48bc0 commit 10c8e14

File tree

1 file changed

+54
-55
lines changed

1 file changed

+54
-55
lines changed

tests/test_NMF_analysis_code.py

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,61 @@
1-
import os
2-
import tempfile
1+
import shutil
2+
from os.path import exists, join
33
from pathlib import Path
4-
from shutil import rmtree
54

65
import pytest
76

87
from diffpy.nmf_mapping.main import main
98

109
dir = Path(__file__).parent.resolve()
11-
12-
data_dir = os.path.join(dir, "data/synthetic_r_vs_gr")
13-
14-
test_map = [
15-
([data_dir, "--xrange", "5,10"], "output_1", "Number of components: 3\n"),
16-
([data_dir], "output_2", "Number of components: 3\n"),
17-
([data_dir, "--xrange", "5,10", "12,15"], "output_3", "Number of components: 3\n"),
18-
]
19-
20-
21-
@pytest.fixture(scope="session")
22-
def temp_dir():
23-
"""A test fixture that creates and destroys tes outputs in a temporary
24-
directory.
25-
This will yield the path to the directory.
26-
"""
27-
cwd = os.getcwd()
28-
name = "outputs"
29-
temp_dir = Path(tempfile.gettempdir())
30-
repo = os.path.join(temp_dir, name)
31-
if os.path.exists(repo):
32-
rmtree(repo)
33-
os.chdir(temp_dir)
34-
os.mkdir(name)
35-
os.chdir(cwd)
36-
yield repo
37-
os.chdir(cwd)
38-
rmtree(repo)
39-
40-
41-
@pytest.mark.parametrize("tm", test_map)
42-
def test_nmf_mapping_code(tm, temp_dir, capsys):
43-
data_dir = tm[0]
44-
working_dir = Path(temp_dir)
45-
os.chdir(working_dir)
46-
main(args=data_dir)
47-
out, err = capsys.readouterr()
48-
assert out == tm[2]
49-
results_dir = os.path.join(working_dir, "nmf_result")
50-
os.chdir(results_dir)
51-
expected_base = os.path.join(os.path.dirname(__file__), "output")
52-
test_specific_dir = os.path.join(expected_base, tm[1])
53-
for root, dirs, files in os.walk("."):
54-
for file in files:
55-
if file in os.listdir(test_specific_dir):
56-
fn1 = os.path.join(results_dir, file)
57-
with open(fn1, "r") as f:
58-
actual = f.read()
59-
fn2 = os.path.join(test_specific_dir, file)
60-
with open(fn2, "r") as f:
61-
expected = f.read()
62-
assert expected == actual
10+
data_dir = join(dir, "data/synthetic_r_vs_gr")
11+
12+
13+
@pytest.mark.parametrize(
14+
"args, output_dir",
15+
[
16+
(["tests/data/synthetic_r_vs_gr", "--xrange", "5,10", "--show", "false"], "output_1"),
17+
(["tests/data/synthetic_r_vs_gr", "--show", "false"], "output_2"),
18+
(["tests/data/synthetic_r_vs_gr", "--xrange", "5,10", "12,15", "--show", "false"], "output_3"),
19+
],
20+
)
21+
def test_nmf_mapping_code(args, output_dir, tmpdir):
22+
23+
# Save the result in ("nmf_result") at the top project level (default behavior)
24+
main(args=args)
25+
26+
# Define the copied results directory in tmpdir
27+
tmp_results_dir = join(tmpdir, "nmf_result")
28+
expected_output_dir = join("tests/output", output_dir)
29+
30+
# Copy the output to tmpdir
31+
shutil.copytree("nmf_result", tmp_results_dir)
32+
33+
# Remove the nmf_result folder from the top project level
34+
shutil.rmtree("nmf_result")
35+
36+
# Define the specific JSON files to check
37+
json_files_to_check = [
38+
"component_index_vs_pratio_col.json",
39+
"component_index_vs_RE_value.json",
40+
"x_index_vs_y_col_components.json",
41+
]
42+
43+
# Compare each specified .json file
44+
for json_file in json_files_to_check:
45+
# Define paths for actual and expected .json files
46+
actual_file_path = join(tmp_results_dir, json_file)
47+
expected_file_path = join(expected_output_dir, json_file)
48+
49+
# Ensure the file exists in both locations
50+
assert exists(actual_file_path)
51+
assert exists(expected_file_path)
52+
53+
# Read and compare file contents
54+
with open(actual_file_path, "r") as actual_file:
55+
actual = actual_file.read()
56+
57+
with open(expected_file_path, "r") as expected_file:
58+
expected = expected_file.read()
59+
60+
# Assert that the contents match
61+
assert actual == expected

0 commit comments

Comments
 (0)