Skip to content

Commit 7074c23

Browse files
committed
moved examp apps
1 parent c51ea91 commit 7074c23

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

example_apps/relaxation_trajectory.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
5+
import numpy as np
6+
import pandas as pd
7+
import plotly.graph_objects as go
8+
from dash import Dash, dcc, html
9+
from dash.dependencies import Input, Output
10+
from pymatgen.core import Structure
11+
from pymatgen.ext.matproj import MPRester
12+
13+
import crystal_toolkit.components as ctc
14+
from crystal_toolkit.settings import SETTINGS
15+
16+
mp_id = "mp-1033715"
17+
with MPRester(monty_decode=False) as mpr:
18+
[task_doc] = mpr.tasks.search(task_ids=[mp_id])
19+
20+
steps = [
21+
(
22+
step.structure,
23+
step.e_fr_energy,
24+
np.linalg.norm(step.forces, axis=1).mean(),
25+
)
26+
for calc in reversed(task_doc.calcs_reversed)
27+
for step in calc.output.ionic_steps
28+
]
29+
assert len(steps) == 99
30+
31+
e_col = "Energy (eV)"
32+
force_col = "Force (eV/Å)"
33+
spg_col = "Spacegroup"
34+
struct_col = "Structure"
35+
36+
df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col])
37+
df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info)
38+
39+
40+
def plot_energy_and_forces(
41+
df: pd.DataFrame,
42+
step: int,
43+
e_col: str,
44+
force_col: str,
45+
title: str,
46+
) -> go.Figure:
47+
"""Plot energy and forces as a function of relaxation step."""
48+
fig = go.Figure()
49+
# energy trace = primary y-axis
50+
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))
51+
52+
# forces trace = secondary y-axis
53+
fig.add_trace(
54+
go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
55+
)
56+
57+
fig.update_layout(
58+
template="plotly_white",
59+
title=title,
60+
xaxis={"title": "Relaxation Step"},
61+
yaxis={"title": e_col},
62+
yaxis2={"title": force_col, "overlaying": "y", "side": "right"},
63+
legend=dict(yanchor="top", y=1, xanchor="right", x=1),
64+
)
65+
66+
# vertical line at the specified step
67+
fig.add_vline(x=step, line={"dash": "dash", "width": 1})
68+
69+
return fig
70+
71+
72+
if "struct_comp" not in locals():
73+
struct_comp = ctc.StructureMoleculeComponent(
74+
id="structure", struct_or_mol=df_traj[struct_col][0]
75+
)
76+
77+
step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps
78+
slider = dcc.Slider(
79+
id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag"
80+
)
81+
82+
83+
def make_title(spg: tuple[str, int]) -> str:
84+
"""Return a title for the figure."""
85+
href = f"https://materialsproject.org/materials/{mp_id}/"
86+
return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})"
87+
88+
89+
title = make_title(df_traj[spg_col][0])
90+
graph = dcc.Graph(
91+
id="fig",
92+
figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),
93+
style={"maxWidth": "50%"},
94+
)
95+
96+
app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
97+
app.layout = html.Div(
98+
[
99+
html.H1(
100+
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
101+
),
102+
html.P("Drag slider to see structure at different relaxation steps."),
103+
slider,
104+
html.Div(
105+
[struct_comp.layout(), graph],
106+
style=dict(display="flex", gap="2em", placeContent="center"),
107+
),
108+
],
109+
style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"),
110+
)
111+
112+
ctc.register_crystal_toolkit(app=app, layout=app.layout)
113+
114+
115+
@app.callback(
116+
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
117+
)
118+
def update_structure(step: int) -> tuple[Structure, go.Figure]:
119+
"""Update the structure displayed in the StructureMoleculeComponent and the
120+
dashed vertical line in the figure when the slider is moved.
121+
"""
122+
title = make_title(df_traj[spg_col][step])
123+
fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)
124+
125+
return df_traj[struct_col][step], fig
126+
127+
128+
# https://stackoverflow.com/a/74918941
129+
is_jupyter = "ipykernel" in sys.modules
130+
131+
if __name__ == "__main__":
132+
app.run(port=8050, debug=True, use_reloader=not is_jupyter)

tests/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ def test_files():
1313
"""The path to the test_files directory."""
1414
return Path(__file__).parent / "test_files"
1515

16+
@pytest.fixture(scope="session")
17+
def example_apps():
18+
"""Return each `app` object defined in the files in the example_apps directory."""
19+
examples_dir = Path(__file__).parent.parent / "example_apps"
20+
files = examples_dir.glob("*.py")
21+
return [file for file in files]
22+
1623

1724
@pytest.fixture(scope="session")
1825
def standard_scenes():

tests/test_example_apps.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from pathlib import Path
2+
import pytest
3+
from importlib import import_module
4+
from dash import Dash
5+
6+
7+
@pytest.fixture(scope="session")
8+
def example_apps():
9+
"""Return paths to example app files."""
10+
examples_dir = Path(__file__).parent.parent / "example_apps"
11+
return list(examples_dir.glob("*.py"))
12+
13+
14+
def test_example_apps(example_apps):
15+
"""Check each app is a valid Dash instance and can handle a basic request."""
16+
for app_path in example_apps:
17+
# Import the app module
18+
relative_path = app_path.relative_to(app_path.parent.parent)
19+
module_name = str(relative_path.with_suffix('')).replace('/', '.')
20+
module = import_module(module_name)
21+
22+
# Check app exists and is a Dash app
23+
app = getattr(module, 'app', None)
24+
assert app is not None, f"No 'app' object found in {app_path}"
25+
assert isinstance(app, Dash), f"'app' object in {app_path} is not a Dash app"
26+
27+
# Use Flask's test client instead of running the server
28+
with app.server.test_client() as client:
29+
response = client.get('/')
30+
assert response.status_code in (200, 302) # OK or redirect are both fine

0 commit comments

Comments
 (0)