Skip to content

Commit abf8bcf

Browse files
committed
rearranged tests to subfiles
1 parent ddf0614 commit abf8bcf

File tree

6 files changed

+480
-463
lines changed

6 files changed

+480
-463
lines changed

tests/test_analysis_tools.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
from unittest.mock import MagicMock, mock_open, patch
3+
4+
import pytest
5+
6+
from mdagent.tools.base_tools import VisFunctions
7+
from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools
8+
from mdagent.utils import PathRegistry
9+
10+
11+
@pytest.fixture
12+
def get_registry():
13+
return PathRegistry()
14+
15+
16+
@pytest.fixture
17+
def plotting_tools(get_registry):
18+
return PlottingTools(get_registry)
19+
20+
21+
@pytest.fixture
22+
def vis_fxns(get_registry):
23+
return VisFunctions(get_registry)
24+
25+
26+
@pytest.fixture
27+
def path_to_cif():
28+
# Save original working directory
29+
original_cwd = os.getcwd()
30+
31+
# Change current working directory to the directory where the CIF file is located
32+
tests_dir = os.path.dirname(os.path.abspath(__file__))
33+
os.chdir(tests_dir)
34+
35+
# Yield the filename only
36+
filename_only = "3pqr.cif"
37+
yield filename_only
38+
39+
# Restore original working directory after the test is done
40+
os.chdir(original_cwd)
41+
42+
43+
def test_process_csv(plotting_tools):
44+
mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25"
45+
mock_reader = MagicMock()
46+
mock_reader.fieldnames = ["Time", "Value1", "Value2"]
47+
mock_reader.__iter__.return_value = iter(
48+
[
49+
{"Time": "1", "Value1": "10", "Value2": "20"},
50+
{"Time": "2", "Value1": "15", "Value2": "25"},
51+
]
52+
)
53+
plotting_tools.file_path = "mock_file.csv"
54+
plotting_tools.file_name = "mock_file.csv"
55+
with patch("builtins.open", mock_open(read_data=mock_csv_content)):
56+
with patch("csv.DictReader", return_value=mock_reader):
57+
plotting_tools.process_csv()
58+
59+
assert plotting_tools.headers == ["Time", "Value1", "Value2"]
60+
assert len(plotting_tools.matched_headers) == 1
61+
assert plotting_tools.matched_headers[0][1] == "Time"
62+
assert len(plotting_tools.data) == 2
63+
assert (
64+
plotting_tools.data[0]["Time"] == "1"
65+
and plotting_tools.data[0]["Value1"] == "10"
66+
)
67+
68+
69+
def test_plot_data(plotting_tools):
70+
# Test successful plot generation
71+
data_success = [
72+
{"Time": "1", "Value1": "10", "Value2": "20"},
73+
{"Time": "2", "Value1": "15", "Value2": "25"},
74+
]
75+
headers = ["Time", "Value1", "Value2"]
76+
matched_headers = [(0, "Time")]
77+
78+
with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch(
79+
"matplotlib.pyplot.xlabel"
80+
), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch(
81+
"matplotlib.pyplot.savefig"
82+
), patch(
83+
"matplotlib.pyplot.close"
84+
):
85+
plotting_tools.data = data_success
86+
plotting_tools.headers = headers
87+
plotting_tools.matched_headers = matched_headers
88+
created_plots = plotting_tools.plot_data()
89+
assert "time_vs_value1.png" in created_plots
90+
assert "time_vs_value2.png" in created_plots
91+
92+
# Test failure due to non-numeric data
93+
data_failure = [
94+
{"Time": "1", "Value1": "A", "Value2": "B"},
95+
{"Time": "2", "Value1": "C", "Value2": "D"},
96+
]
97+
98+
plotting_tools.data = data_failure
99+
plotting_tools.headers = headers
100+
plotting_tools.matched_headers = matched_headers
101+
102+
with pytest.raises(Exception) as excinfo:
103+
plotting_tools.plot_data()
104+
assert "All plots failed due to non-numeric data." in str(excinfo.value)
105+
106+
107+
@pytest.mark.skip(reason="molrender is not pip installable")
108+
def test_run_molrender(path_to_cif, vis_fxns):
109+
result = vis_fxns.run_molrender(path_to_cif)
110+
assert result == "Visualization created"
111+
112+
113+
def test_find_png(vis_fxns):
114+
vis_fxns.starting_files = os.listdir(".")
115+
test_file = "test_image.png"
116+
with open(test_file, "w") as f:
117+
f.write("")
118+
png_files = vis_fxns._find_png()
119+
assert test_file in png_files
120+
121+
os.remove(test_file)
122+
123+
124+
def test_create_notebook(path_to_cif, vis_fxns):
125+
result = vis_fxns.create_notebook(path_to_cif)
126+
path_to_notebook = path_to_cif.split(".")[0] + "_vis.ipynb"
127+
os.remove(path_to_notebook)
128+
assert result == "Visualization Complete"

0 commit comments

Comments
 (0)