Skip to content

Commit

Permalink
redoing chi angle analysis for computeangle tools
Browse files Browse the repository at this point in the history
  • Loading branch information
Jgmedina95 committed Jan 27, 2025
1 parent 15b7ccc commit 52aae41
Showing 1 changed file with 141 additions and 77 deletions.
218 changes: 141 additions & 77 deletions mdagent/tools/base_tools/analysis_tools/bond_angles_dihedrals_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class ComputingAnglesSchema(BaseModel):
description=(
"Which analysis to be done. Availables are: "
"phi-psi (saves a Ramachandran plot and histograms for the Phi-Psi angles),"
"chi1-chi2 (gets the chi1 and chi2 dihedral angles and the chi1-chi2 plot"
"is saved. For the plots it only uses sidechains with enough carbons),"
"all (makes all of the previous analysis)"
"chis (gets the chis 1-4 angles and plots a time evolutiuon plot for all"
"residues is saved. For the plots it only uses sidechains with enough "
"carbons), all (makes all of the previous analysis)"
),
)
# This arg is here, but is not used in the code. As of now it will get the analysis
Expand Down Expand Up @@ -164,95 +164,159 @@ def compute_and_plot_phi_psi(self, traj, path_registry, sim_id):
"Succeeded. Computed phi-psi angles (no path_registry to save).",
)

def compute_and_plot_chi1_chi2(self, traj, path_registry, sim_id):
def classify_chi(self, ang_deg, res_name=""):
"""Return an integer code depending on angle range."""
# Example classification with made-up intervals:
if res_name == "PRO" or res_name == "P":
if ang_deg < 0:
return 3 # e.g. "p-"
else:
return 4 # e.g. "p+"
# angles for g+
if 0 <= ang_deg < 120:
return 0 # e.g. "g+"
# angles for t
elif -120 >= ang_deg or ang_deg > 120:
return 1 # e.g. "t"
# angles for g-
elif -120 <= ang_deg < 0:
return 2 # e.g. "g-"

# function that takes an array and classifies the angles
def classify_chi_angles(self, angles, res_name=""):
return [self.classify_chi(ang, res_name) for ang in angles]

def _plot_one_chi_angle(self, ax, angle_array, residue_names, title=None):
"""
Computes chi1-chi2 angles, saves results to file, and produces Chi1-Chi2 plot.
Classify angles per residue/frame, then do imshow on a given Axes.
angle_array: shape (n_frames, n_residues) or (n_residues, n_frames)
residue_names: length n_residues
"""
try:
# Compute chi1 and chi2 angles
chi1_indices, chi1_angles = md.compute_chi1(traj)
chi2_indices, chi2_angles = md.compute_chi2(traj)

# Convert angles to degrees
chi1_angles = chi1_angles * (180.0 / np.pi)
chi2_angles = chi2_angles * (180.0 / np.pi)
except Exception as e:
return None, f"Failed. Error computing chi1-chi2 angles: {str(e)}"

# If path_registry is available, save files and produce plot
if path_registry is not None:
# Get the indices of the first side-chain atoms from chi1 and chi2
chi1_atoms = [atom_idx[1] for atom_idx in chi1_indices]
chi2_atoms = [atom_idx[0] for atom_idx in chi2_indices]

# Filter chi1 angles to match atoms that appear in chi2
chi1_angles_long = np.array(
[
chi1_angles[:, i]
for i, chi1_atom in enumerate(chi1_atoms)
if chi1_atom in chi2_atoms
]
)

# Save angle results
save_results_to_file("chi1_results.npz", chi1_indices, chi1_angles)
save_results_to_file("chi2_results.npz", chi2_indices, chi2_angles)

# Make Chi1-Chi2 plot
try:
plt.hist2d(
chi1_angles_long.T.flatten(),
chi2_angles.flatten(),
bins=200,
cmap="Blues",
)
plt.xlabel(r"$\chi1$")
plt.ylabel(r"$\chi2$")
plt.title(f"Chi1-Chi2 plot for the simulation {sim_id}")
plt.colorbar()

file_name = path_registry.write_file_name(
FileType.FIGURE,
fig_analysis="chi1-chi2",
file_format="png",
Sim_id=sim_id,
)
desc = f"Chi1-Chi2 plot for the simulation {sim_id}"
chi_plot_id = path_registry.get_fileid(file_name, FileType.FIGURE)
path = path_registry.ckpt_dir + "/figures/"
plt.savefig(path + file_name)
path_registry.map_path(chi_plot_id, path + file_name, description=desc)
plt.clf() # Clear the current figure so it does not overlay next plot
print("Chi1-Chi2 plot saved to file")
return chi_plot_id, "Succeeded. Chi1-Chi2 plot saved."
except Exception as e:
return None, f"Failed. Error saving Chi1-Chi2 plot: {str(e)}"
else:

return None, "Succeeded. Computed chi1-chi2 angles."
state_sequence = np.array(
[
[self.classify_chi_angles(a, str(name)[:3])]
for i, (a, name) in enumerate(zip(angle_array.T, residue_names))
]
)
states_per_res = state_sequence.reshape(
state_sequence.shape[0], state_sequence.shape[2]
) # shape = (#res,1, #frames)
# -> (#res, #frames)

n_residues = len(residue_names)
unique_states = np.unique(states_per_res)
n_states = len(unique_states)
cmap = plt.get_cmap("tab20", n_states)

im = ax.imshow(
states_per_res,
aspect="auto",
interpolation="none",
cmap=cmap,
origin="upper",
)

ax.set_xlabel("Frame index")
ax.set_ylabel("Residue")
if title:
ax.set_title(title)

ax.set_yticks(np.arange(n_residues))
ax.set_yticklabels([str(r) for r in residue_names], fontsize=8)

cbar = plt.colorbar(im, ax=ax, ticks=range(n_states), pad=0.01)

# Example state -> label mapping
state_labels_map = {0: "g+", 1: "t", 2: "g-", 3: "Cγ endo", 4: "Cγ exo"}
tick_labels = [state_labels_map.get(s, f"State {s}") for s in unique_states]
cbar.ax.set_yticklabels(tick_labels, fontsize=8)

###################################################
# Main function to produce a single figure w/ 4 subplots
###################################################
def compute_plot_all_chi_angles(self, traj, sim_id="sim"):
"""
Create one figure with 4 subplots (2x2):
- subplot(0,0): χ1
- subplot(0,1): χ2
- subplot(1,0): χ3
- subplot(1,1): χ4
"""
chi1_indices, chi_1_angles = md.compute_chi1(traj)
chi2_indices, chi_2_angles = md.compute_chi2(traj)
chi3_indices, chi_3_angles = md.compute_chi3(traj)
chi4_indices, chi_4_angles = md.compute_chi4(traj)

chi_1_angles_degrees = np.rad2deg(chi_1_angles)
chi_2_angles_degrees = np.rad2deg(chi_2_angles)
chi_3_angles_degrees = np.rad2deg(chi_3_angles)
chi_4_angles_degrees = np.rad2deg(chi_4_angles)
residue_names_1 = [traj.topology.atom(i).residue for i in chi1_indices[:, 1]]
residue_names_2 = [traj.topology.atom(i).residue for i in chi2_indices[:, 1]]
residue_names_3 = [traj.topology.atom(i).residue for i in chi3_indices[:, 1]]
residue_names_4 = [traj.topology.atom(i).residue for i in chi4_indices[:, 1]]
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top-left: χ1
self._plot_one_chi_angle(
axes[0, 0], chi_1_angles_degrees, residue_names_1, title=r"$\chi$1"
)

# Top-right: χ2
self._plot_one_chi_angle(
axes[0, 1], chi_2_angles_degrees, residue_names_2, title="$\chi$2"
)

# Bottom-left: χ3
self._plot_one_chi_angle(
axes[1, 0], chi_3_angles_degrees, residue_names_3, title="$\chi$3"
)

# Bottom-right: χ4
self._plot_one_chi_angle(
axes[1, 1], chi_4_angles_degrees, residue_names_4, title="$\chi$4"
)
# add title
fig.suptitle("Chi angles per residue for simulation {sim}", fontsize=16)
plt.tight_layout()
# plt.show()
# Save the figure
file_name = self.path_registry.write_file_name(
FileType.FIGURE,
fig_analysis="chi_angles",
file_format="png",
Sim_id=sim_id,
)
desc = f"Chi angles plot for the simulation {sim_id}"
plot_id = self.path_registry.get_fileid(file_name, FileType.FIGURE)
path = self.path_registry.ckpt_dir + "/figures/"
plt.savefig(path + file_name)
self.path_registry.map_path(plot_id, path + file_name, description=desc)
plt.clf() # Clear the current figure so it does not overlay next plot
return plot_id, "Succeeded. Chi angles plot saved."

def analyze_trajectory(self, traj, analysis, path_registry=None, sim_id="sim"):
"""
Main function to decide which analysis to do:
'phi-psi', 'chi1-chi2', or 'all'.
'phi-psi', 'chis', or 'all'.
"""
# Store optional references for convenience
self_path_registry = path_registry
self_sim_id = sim_id

# ================ PHI-PSI ONLY =================
if analysis == "phi-psi":
plot_id, message = self.compute_and_plot_phi_psi(
ram_plot_id, phi_message = self.compute_and_plot_phi_psi(
traj, self_path_registry, self_sim_id
)
return message
return f"Ramachandran plot with ID {ram_plot_id}, message: {phi_message} "

# ================ CHI1-CHI2 ONLY ================
elif analysis == "chi1-chi2":
plot_id, message = self.compute_and_plot_chi1_chi2(
elif analysis == "chis":
chi_plot_id, chi_message = self.compute_plot_all_chi_angles(
traj, self_path_registry, self_sim_id
)
return message
return f"Chis plot with ID {chi_plot_id}, message: {chi_message}"

# ================ ALL =================
elif analysis == "all":
Expand All @@ -264,16 +328,16 @@ def analyze_trajectory(self, traj, analysis, path_registry=None, sim_id="sim"):
return phi_message

# Then do chi1-chi2
chi_plot_id, chi_message = self.compute_and_plot_chi1_chi2(
chi_plot_id, chi_message = self.compute_plot_all_chi_angles(
traj, self_path_registry, self_sim_id
)
if "Failed." in chi_message:
return chi_message

return (
"Succeeded. All analyses completed. "
f"Ramachandran plot message: {phi_message} "
f"Chi1-Chi2 plot message: {chi_message}"
f"Ramachandran plot with ID {ram_plot_id}, message: {phi_message} "
f"Chis plot with ID {chi_plot_id}, message: {chi_message}"
)

else:
Expand Down Expand Up @@ -303,7 +367,7 @@ def validate_input(self, **input):
if analysis.lower() not in [
"all",
"phi-psi",
"chi1-chi2",
"chis",
]:
analysis = "all"
system_message += (
Expand Down

0 comments on commit 52aae41

Please sign in to comment.