Skip to content

Commit

Permalink
added rmsd tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qcampbel committed Mar 18, 2024
1 parent f340560 commit e1d2a91
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 6 deletions.
25 changes: 19 additions & 6 deletions mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,25 @@ def __init__(self, path_registry, pdb, traj, ref=None, ref_traj=None):
self.path_registry = path_registry
self.pdb_file = self.path_registry.get_mapped_path(pdb)
self.trajectory = self.path_registry.get_mapped_path(traj)
self.pdb_name = os.path.splitext(os.path.basename(self.pdb_file))[0]
self.ref_file = self.path_registry.get_mapped_path(ref)
self.ref_trajectory = self.path_registry.get_mapped_path(ref_traj)
if self.ref_file:

# check for missing paths
if self.pdb_file == "Name not found in path registry.":
# set that file to None
self.pdb_file = None
self.pdb_name = None
else:
self.pdb_name = os.path.splitext(os.path.basename(self.pdb_file))[0]
if self.trajectory == "Name not found in path registry.":
self.trajectory = None
if self.ref_file == "Name not found in path registry." or self.ref_file is None:
self.ref_file = None
self.ref_name = None
else:
self.ref_name = os.path.splitext(os.path.basename(self.ref_file))[0]
if self.ref_trajectory == "Name not found in path registry.":
self.ref_trajectory = None

def calculate_rmsd(
self,
Expand Down Expand Up @@ -119,8 +133,8 @@ def compute_rmsd(self, selection="backbone", plot=True):
plt.ylabel("RMSD ($\AA$)")
plt.title("Time-Dependent RMSD")
plt.legend()
plt.show()
plt.savefig(f"{self.filename}.png")
# plt.show()
# plt.close() # if you don't want to show the plot in notebooks
message += f"Plotted RMSD over time. Saved to {self.filename}.png.\n"
self.path_registry.map_path(
Expand Down Expand Up @@ -165,7 +179,7 @@ def compute_2d_rmsd(self, selection="backbone", plot_heatmap=True):
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.colorbar(label=r"RMSD ($\AA$)")
plt.show()
# plt.show()
plt.savefig(f"{self.filename}.png")
message += f"Plotted pairwise RMSD matrix. Saved to {self.filename}.png.\n"
self.path_registry.map_path(
Expand All @@ -187,7 +201,6 @@ def compute_rmsf(self, selection="backbone", plot=True):
R = rms.RMSF(atoms).run()
rmsf = R.results.rmsf

# Save to a text file
rmsf_data = np.column_stack((atoms.resids, rmsf))
np.savetxt(
f"{self.filename}.csv",
Expand All @@ -209,8 +222,8 @@ def compute_rmsf(self, selection="backbone", plot=True):
plt.ylabel("RMSF ($\AA$)")
plt.title("Root Mean Square Fluctuation")
plt.legend()
plt.show()
plt.savefig(f"{self.filename}.png")
# plt.show()
message += f"Plotted RMSF. Saved to {self.filename}.png.\n"
self.path_registry.map_path(
f"{self.filename}.png", f"{self.filename}.png", message
Expand Down
116 changes: 116 additions & 0 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from unittest.mock import MagicMock, mock_open, patch

import MDAnalysis as mda
import numpy as np
import pytest

from mdagent.tools.base_tools import VisFunctions
from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools
from mdagent.tools.base_tools.analysis_tools.ppi_tools import ppi_distance
from mdagent.tools.base_tools.analysis_tools.rmsd_tools import RMSDFunctions
from mdagent.utils import PathRegistry


Expand All @@ -26,6 +28,13 @@ def vis_fxns(get_registry):
return VisFunctions(get_registry)


@pytest.fixture
def rmsd_functions():
path_registry_mock = MagicMock() # to avoid file system dependence
path_registry_mock.get_mapped_path.side_effect = lambda x: x # simply echo input
return RMSDFunctions(path_registry_mock, "pdb_file.pdb", "trajectory.dcd")


@pytest.fixture
def path_to_cif():
# Save original working directory
Expand Down Expand Up @@ -157,7 +166,114 @@ def mock_mda_universe():
yield mock_universe


@pytest.fixture
def mock_savetxt():
with patch(
"mdagent.tools.base_tools.analysis_tools.rmsd_tools.np.savetxt"
) as mock_savetxt:
yield mock_savetxt


@pytest.fixture
def mock_rmsd_run():
mock_rmsd_results = MagicMock()
mock_rmsd_results.rmsd = np.array([[0, 0.0, 0.1], [1, 10.0, 0.2]])
with patch(
"mdagent.tools.base_tools.analysis_tools.ppi_tools.mda.analysis.rms.RMSD",
return_value=MagicMock(results=mock_rmsd_results),
) as mock:
yield mock


@pytest.fixture
def mock_plt_savefig():
with patch("mdagent.tools.base_tools.analysis_tools.rmsd_tools.plt.savefig") as plt:
yield plt


def test_ppi_distance(mock_mda_universe):
file_path = "dummy_path.pdb"
avg_dist = ppi_distance(file_path)
assert avg_dist > 0, "Expected a positive average distance"


def test_calculate_rmsd_with_invalid_rmsd_type(rmsd_functions):
with pytest.raises(ValueError):
rmsd_functions.calculate_rmsd(rmsd_type="invalid_type")


def test_calculate_rmsd_without_ref_file(rmsd_functions):
with patch.object(
rmsd_functions, "compute_rmsd", return_value="RMSD calculated"
) as mock_method:
result = rmsd_functions.calculate_rmsd()
mock_method.assert_called_once()
assert "RMSD calculated" in result


def test_compute_rmsd_2sets(mock_mda_universe, rmsd_functions):
with patch(
"mdagent.tools.base_tools.analysis_tools.ppi_tools.mda.analysis.rms.rmsd",
return_value=0.5,
) as mock_rmsd:
result = rmsd_functions.compute_rmsd_2sets(selection="backbone")
assert "0.5" in result, "RMSD value should be present in the result string"
mock_mda_universe.assert_called()
mock_rmsd.assert_called()


def test_compute_rmsd(mock_mda_universe, mock_rmsd_run, mock_savetxt, rmsd_functions):
rmsd_functions.filename = "test_rmsd" # avoid filesystem dependence
message = rmsd_functions.compute_rmsd(selection="backbone", plot=True)

mock_mda_universe.assert_called()
mock_rmsd_run.assert_called()
mock_savetxt.assert_called()
args, kwargs = mock_savetxt.call_args
assert "test_rmsd.csv" in args, "Expected np.savetxt to save to correct file"
assert "Average RMSD is 0.15" in message, "Expected correct average RMSD in message"
assert "Final RMSD is 0.2" in message, "Expected correct final RMSD in message"
assert "Saved to test_rmsd.csv" in message, "Expected correct save file message"


@pytest.mark.parametrize("plot_enabled", [True, False])
def test_compute_rmsd_plotting(
plot_enabled, mock_mda_universe, mock_plt_savefig, rmsd_functions
):
rmsd_functions.filename = "test_rmsd"
message = rmsd_functions.compute_rmsd(selection="backbone", plot=plot_enabled)
if plot_enabled:
mock_plt_savefig.assert_called_with("test_rmsd.png")
assert (
"Plotted RMSD over time. Saved to test_rmsd.png." in message
), "Expected correct plotting message"
else:
mock_plt_savefig.assert_not_called()


def test_compute_2d_rmsd(mock_mda_universe, mock_savetxt, rmsd_functions):
# TODO: test plotting as well
rmsd_functions.filename = "test_pairwise_rmsd"
patch_path = (
"mdagent.tools.base_tools.analysis_tools.ppi_tools.mda.analysis."
"diffusionmap.DistanceMatrix.run"
)
with patch(patch_path) as mock_distance_matrix_run:
result = rmsd_functions.compute_2d_rmsd(
selection="backbone", plot_heatmap=False
)
mock_mda_universe.assert_called()
mock_distance_matrix_run.assert_called()
mock_savetxt.assert_called()
assert "Saved pairwise RMSD matrix" in result


# def test_calculate_rmsd(rmsd_functions):
# with patch(RMSDFunctions, 'compute_rmsd_2sets') as mock_compute_2sets, \
# patch(RMSDFunctions, 'compute_rmsd') as mock_compute_rmsd, \
# patch(RMSDFunctions, 'compute_2d_rmsd') as mock_compute_2d_rmsd, \
# patch(RMSDFunctions, 'compute_rmsf') as mock_compute_rmsf:

# rmsd_functions.ref_file = 'ref.pdb'
# rmsd_functions.calculate_rmsd(rmsd_type='rmsd')
# mock_compute_2sets.assert_called_once()

0 comments on commit e1d2a91

Please sign in to comment.