|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import sys |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | +import plotly.graph_objects as go |
| 8 | +from dash import Dash, dcc, html |
| 9 | +from dash.dependencies import Input, Output |
| 10 | +from pymatgen.core import Structure |
| 11 | +from pymatgen.ext.matproj import MPRester |
| 12 | + |
| 13 | +import crystal_toolkit.components as ctc |
| 14 | +from crystal_toolkit.settings import SETTINGS |
| 15 | + |
| 16 | +mp_id = "mp-1033715" |
| 17 | +with MPRester(monty_decode=False) as mpr: |
| 18 | + [task_doc] = mpr.tasks.search(task_ids=[mp_id]) |
| 19 | + |
| 20 | +steps = [ |
| 21 | + ( |
| 22 | + step.structure, |
| 23 | + step.e_fr_energy, |
| 24 | + np.linalg.norm(step.forces, axis=1).mean(), |
| 25 | + ) |
| 26 | + for calc in reversed(task_doc.calcs_reversed) |
| 27 | + for step in calc.output.ionic_steps |
| 28 | +] |
| 29 | +assert len(steps) == 99 |
| 30 | + |
| 31 | +e_col = "Energy (eV)" |
| 32 | +force_col = "Force (eV/Å)" |
| 33 | +spg_col = "Spacegroup" |
| 34 | +struct_col = "Structure" |
| 35 | + |
| 36 | +df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col]) |
| 37 | +df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info) |
| 38 | + |
| 39 | + |
| 40 | +def plot_energy_and_forces( |
| 41 | + df: pd.DataFrame, |
| 42 | + step: int, |
| 43 | + e_col: str, |
| 44 | + force_col: str, |
| 45 | + title: str, |
| 46 | +) -> go.Figure: |
| 47 | + """Plot energy and forces as a function of relaxation step.""" |
| 48 | + fig = go.Figure() |
| 49 | + # energy trace = primary y-axis |
| 50 | + fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy")) |
| 51 | + |
| 52 | + # forces trace = secondary y-axis |
| 53 | + fig.add_trace( |
| 54 | + go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2") |
| 55 | + ) |
| 56 | + |
| 57 | + fig.update_layout( |
| 58 | + template="plotly_white", |
| 59 | + title=title, |
| 60 | + xaxis={"title": "Relaxation Step"}, |
| 61 | + yaxis={"title": e_col}, |
| 62 | + yaxis2={"title": force_col, "overlaying": "y", "side": "right"}, |
| 63 | + legend=dict(yanchor="top", y=1, xanchor="right", x=1), |
| 64 | + ) |
| 65 | + |
| 66 | + # vertical line at the specified step |
| 67 | + fig.add_vline(x=step, line={"dash": "dash", "width": 1}) |
| 68 | + |
| 69 | + return fig |
| 70 | + |
| 71 | + |
| 72 | +if "struct_comp" not in locals(): |
| 73 | + struct_comp = ctc.StructureMoleculeComponent( |
| 74 | + id="structure", struct_or_mol=df_traj[struct_col][0] |
| 75 | + ) |
| 76 | + |
| 77 | +step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps |
| 78 | +slider = dcc.Slider( |
| 79 | + id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag" |
| 80 | +) |
| 81 | + |
| 82 | + |
| 83 | +def make_title(spg: tuple[str, int]) -> str: |
| 84 | + """Return a title for the figure.""" |
| 85 | + href = f"https://materialsproject.org/materials/{mp_id}/" |
| 86 | + return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})" |
| 87 | + |
| 88 | + |
| 89 | +title = make_title(df_traj[spg_col][0]) |
| 90 | +graph = dcc.Graph( |
| 91 | + id="fig", |
| 92 | + figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title), |
| 93 | + style={"maxWidth": "50%"}, |
| 94 | +) |
| 95 | + |
| 96 | +app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH) |
| 97 | +app.layout = html.Div( |
| 98 | + [ |
| 99 | + html.H1( |
| 100 | + "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em") |
| 101 | + ), |
| 102 | + html.P("Drag slider to see structure at different relaxation steps."), |
| 103 | + slider, |
| 104 | + html.Div( |
| 105 | + [struct_comp.layout(), graph], |
| 106 | + style=dict(display="flex", gap="2em", placeContent="center"), |
| 107 | + ), |
| 108 | + ], |
| 109 | + style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"), |
| 110 | +) |
| 111 | + |
| 112 | +ctc.register_crystal_toolkit(app=app, layout=app.layout) |
| 113 | + |
| 114 | + |
| 115 | +@app.callback( |
| 116 | + Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value") |
| 117 | +) |
| 118 | +def update_structure(step: int) -> tuple[Structure, go.Figure]: |
| 119 | + """Update the structure displayed in the StructureMoleculeComponent and the |
| 120 | + dashed vertical line in the figure when the slider is moved. |
| 121 | + """ |
| 122 | + title = make_title(df_traj[spg_col][step]) |
| 123 | + fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title) |
| 124 | + |
| 125 | + return df_traj[struct_col][step], fig |
| 126 | + |
| 127 | + |
| 128 | +# https://stackoverflow.com/a/74918941 |
| 129 | +is_jupyter = "ipykernel" in sys.modules |
| 130 | + |
| 131 | +if __name__ == "__main__": |
| 132 | + app.run(port=8050, debug=True, use_reloader=not is_jupyter) |
0 commit comments