Skip to content

Commit

Permalink
added rmsf test
Browse files Browse the repository at this point in the history
  • Loading branch information
qcampbel committed Mar 18, 2024
1 parent e1d2a91 commit c57cefe
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 28 deletions.
4 changes: 4 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ def compute_rmsf(self, selection="backbone", plot=True):
atoms = u.select_atoms(selection)
R = rms.RMSF(atoms).run()
rmsf = R.results.rmsf
self.process_rmsf_results(atoms, rmsf, selection=selection, plot=plot)

def process_rmsf_results(self, atoms, rmsf, selection="backbone", plot=True):
print(f"rmsf: {rmsf}")
print("atoms.resids: ", atoms.resids)
rmsf_data = np.column_stack((atoms.resids, rmsf))
np.savetxt(
f"{self.filename}.csv",
Expand Down
126 changes: 98 additions & 28 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,6 @@ def mock_mda_universe():
yield mock_universe


@pytest.fixture
def mock_savetxt():
with patch(
"mdagent.tools.base_tools.analysis_tools.rmsd_tools.np.savetxt"
) as mock_savetxt:
yield mock_savetxt


@pytest.fixture
def mock_rmsd_run():
mock_rmsd_results = MagicMock()
Expand All @@ -185,6 +177,14 @@ def mock_rmsd_run():
yield mock


@pytest.fixture
def mock_savetxt():
with patch(
"mdagent.tools.base_tools.analysis_tools.rmsd_tools.np.savetxt"
) as mock_savetxt:
yield mock_savetxt


@pytest.fixture
def mock_plt_savefig():
with patch("mdagent.tools.base_tools.analysis_tools.rmsd_tools.plt.savefig") as plt:
Expand All @@ -197,18 +197,58 @@ def test_ppi_distance(mock_mda_universe):
assert avg_dist > 0, "Expected a positive average distance"


def test_calculate_rmsd_with_invalid_rmsd_type(rmsd_functions):
with pytest.raises(ValueError):
rmsd_functions.calculate_rmsd(rmsd_type="invalid_type")
def test_calculate_rmsd(rmsd_functions):
# Mock all related compute_* methods in rmsd_functions
with patch.object(
rmsd_functions, "compute_rmsd_2sets"
) as mock_compute_2sets, patch.object(
rmsd_functions, "compute_rmsd"
) as mock_compute_rmsd, patch.object(
rmsd_functions, "compute_2d_rmsd"
) as mock_compute_2d_rmsd, patch.object(
rmsd_functions, "compute_rmsf"
) as mock_compute_rmsf:
# Test rmsd_type="rmsd" with a reference file (call compute_rmsd_2sets)
rmsd_functions.ref_file = "ref.pdb"
rmsd_functions.calculate_rmsd(rmsd_type="rmsd")
mock_compute_2sets.assert_called_once_with(selection="backbone")
mock_compute_rmsd.assert_not_called()
mock_compute_2d_rmsd.assert_not_called()
mock_compute_rmsf.assert_not_called()

mock_compute_2sets.reset_mock()

# Test rmsd_type="rmsd" without a reference file (compute_rmsd should be called)
rmsd_functions.ref_file = None
rmsd_functions.calculate_rmsd(rmsd_type="rmsd")
mock_compute_rmsd.assert_called_once_with(selection="backbone", plot=True)
mock_compute_2sets.assert_not_called()
mock_compute_2d_rmsd.assert_not_called()
mock_compute_rmsf.assert_not_called()

mock_compute_rmsd.reset_mock()

# Test rmsd_type="pairwise_rmsd" (compute_2d_rmsd should be called)
rmsd_functions.calculate_rmsd(rmsd_type="pairwise_rmsd")
mock_compute_2d_rmsd.assert_called_once_with(
selection="backbone", plot_heatmap=True
)
mock_compute_2sets.assert_not_called()
mock_compute_rmsd.assert_not_called()
mock_compute_rmsf.assert_not_called()

mock_compute_2d_rmsd.reset_mock()

def test_calculate_rmsd_without_ref_file(rmsd_functions):
with patch.object(
rmsd_functions, "compute_rmsd", return_value="RMSD calculated"
) as mock_method:
result = rmsd_functions.calculate_rmsd()
mock_method.assert_called_once()
assert "RMSD calculated" in result
# Test rmsd_type="rmsf" (compute_rmsf should be called)
rmsd_functions.calculate_rmsd(rmsd_type="rmsf")
mock_compute_rmsf.assert_called_once_with(selection="backbone", plot=True)
mock_compute_2sets.assert_not_called()
mock_compute_rmsd.assert_not_called()
mock_compute_2d_rmsd.assert_not_called()

# Test for invalid rmsd_type (should raise ValueError)
with pytest.raises(ValueError):
rmsd_functions.calculate_rmsd(rmsd_type="invalid_rmsd_type")


def test_compute_rmsd_2sets(mock_mda_universe, rmsd_functions):
Expand All @@ -223,7 +263,7 @@ def test_compute_rmsd_2sets(mock_mda_universe, rmsd_functions):


def test_compute_rmsd(mock_mda_universe, mock_rmsd_run, mock_savetxt, rmsd_functions):
rmsd_functions.filename = "test_rmsd" # avoid filesystem dependence
rmsd_functions.filename = "test_rmsd"
message = rmsd_functions.compute_rmsd(selection="backbone", plot=True)

mock_mda_universe.assert_called()
Expand Down Expand Up @@ -252,7 +292,6 @@ def test_compute_rmsd_plotting(


def test_compute_2d_rmsd(mock_mda_universe, mock_savetxt, rmsd_functions):
# TODO: test plotting as well
rmsd_functions.filename = "test_pairwise_rmsd"
patch_path = (
"mdagent.tools.base_tools.analysis_tools.ppi_tools.mda.analysis."
Expand All @@ -268,12 +307,43 @@ def test_compute_2d_rmsd(mock_mda_universe, mock_savetxt, rmsd_functions):
assert "Saved pairwise RMSD matrix" in result


# def test_calculate_rmsd(rmsd_functions):
# with patch(RMSDFunctions, 'compute_rmsd_2sets') as mock_compute_2sets, \
# patch(RMSDFunctions, 'compute_rmsd') as mock_compute_rmsd, \
# patch(RMSDFunctions, 'compute_2d_rmsd') as mock_compute_2d_rmsd, \
# patch(RMSDFunctions, 'compute_rmsf') as mock_compute_rmsf:
@pytest.mark.parametrize("plot", [True, False])
def test_process_rmsf_results(
rmsd_functions, tmp_path, mock_plt_savefig, mock_savetxt, plot
):
mock_atoms = MagicMock()
mock_atoms.resids = np.arange(1, 11)
mock_atoms.resnums = np.arange(1, 11)
mock_rmsf_values = np.random.rand(10)
output_csv = tmp_path / "output_rmsf.csv"
output_png = tmp_path / "output_rmsf.png"
rmsd_functions.filename = str(tmp_path / "output_rmsf")
message = rmsd_functions.process_rmsf_results(
mock_atoms, mock_rmsf_values, plot=plot
)
mock_savetxt.assert_called_once()
args, _ = mock_savetxt.call_args
assert args[0] == str(output_csv), "CSV file path passed to np.savetxt don't match."

if plot:
mock_plt_savefig.assert_called_once_with(str(output_png))
assert "Plotted RMSF. Saved to" in message
else:
mock_plt_savefig.assert_not_called()
assert "Saved RMSF data to" in message

# rmsd_functions.ref_file = 'ref.pdb'
# rmsd_functions.calculate_rmsd(rmsd_type='rmsd')
# mock_compute_2sets.assert_called_once()

@pytest.mark.parametrize("plot", [True, False])
def test_compute_rmsf(rmsd_functions, mock_mda_universe, plot):
with patch.object(
rmsd_functions, "process_rmsf_results"
) as mocked_process_rmsf_results:
mocked_process_rmsf_results.return_value = None
rmsd_functions.compute_rmsf(selection="backbone", plot=plot)
mock_mda_universe.assert_called()
mocked_process_rmsf_results.assert_called_once()
args, kwargs = mocked_process_rmsf_results.call_args
selection = kwargs["selection"]
plot_arg = kwargs["plot"]
assert selection == "backbone"
assert plot_arg is plot

0 comments on commit c57cefe

Please sign in to comment.