Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Improve Notebook Rendering and Testing #444

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions example_apps/relaxation_trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester

import crystal_toolkit.components as ctc
from crystal_toolkit.settings import SETTINGS

mp_id = "mp-1033715"
with MPRester(monty_decode=False) as mpr:
[task_doc] = mpr.tasks.search(task_ids=[mp_id])

steps = [
(
step.structure,
step.e_fr_energy,
np.linalg.norm(step.forces, axis=1).mean(),
)
for calc in reversed(task_doc.calcs_reversed)
for step in calc.output.ionic_steps
]
assert len(steps) == 99

e_col = "Energy (eV)"
force_col = "Force (eV/Å)"
spg_col = "Spacegroup"
struct_col = "Structure"

df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col])
df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info)


def plot_energy_and_forces(
df: pd.DataFrame,
step: int,
e_col: str,
force_col: str,
title: str,
) -> go.Figure:
"""Plot energy and forces as a function of relaxation step."""
fig = go.Figure()
# energy trace = primary y-axis
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))

# forces trace = secondary y-axis
fig.add_trace(
go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
)

fig.update_layout(
template="plotly_white",
title=title,
xaxis={"title": "Relaxation Step"},
yaxis={"title": e_col},
yaxis2={"title": force_col, "overlaying": "y", "side": "right"},
legend=dict(yanchor="top", y=1, xanchor="right", x=1),
)

# vertical line at the specified step
fig.add_vline(x=step, line={"dash": "dash", "width": 1})

return fig


if "struct_comp" not in locals():
struct_comp = ctc.StructureMoleculeComponent(
id="structure", struct_or_mol=df_traj[struct_col][0]
)

step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps
slider = dcc.Slider(
id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag"
)


def make_title(spg: tuple[str, int]) -> str:
"""Return a title for the figure."""
href = f"https://materialsproject.org/materials/{mp_id}/"
return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})"


title = make_title(df_traj[spg_col][0])
graph = dcc.Graph(
id="fig",
figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),
style={"maxWidth": "50%"},
)

app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
app.layout = html.Div(
[
html.H1(
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
),
html.P("Drag slider to see structure at different relaxation steps."),
slider,
html.Div(
[struct_comp.layout(), graph],
style=dict(display="flex", gap="2em", placeContent="center"),
),
],
style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"),
)

ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
"""Update the structure displayed in the StructureMoleculeComponent and the
dashed vertical line in the figure when the slider is moved.
"""
title = make_title(df_traj[spg_col][step])
fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)

return df_traj[struct_col][step], fig


# https://stackoverflow.com/a/74918941
is_jupyter = "ipykernel" in sys.modules

if __name__ == "__main__":
app.run(port=8050, debug=True, use_reloader=not is_jupyter)
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def test_files():
return Path(__file__).parent / "test_files"


@pytest.fixture(scope="session")
def example_apps():
"""Return each `app` object defined in the files in the example_apps directory."""
examples_dir = Path(__file__).parent.parent / "example_apps"
files = examples_dir.glob("*.py")
return [file for file in files]


@pytest.fixture(scope="session")
def standard_scenes():
"""Dictionary of standard scenes for testing purposes."""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_example_apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from importlib import import_module
from pathlib import Path

import pytest
from dash import Dash


@pytest.fixture(scope="session")
def example_apps():
"""Return paths to example app files."""
examples_dir = Path(__file__).parent.parent / "example_apps"
return list(examples_dir.glob("*.py"))


def test_example_apps(example_apps):
"""Check each app is a valid Dash instance and can handle a basic request."""
for app_path in example_apps:
# Import the app module
relative_path = app_path.relative_to(app_path.parent.parent)
module_name = str(relative_path.with_suffix("")).replace("/", ".")
module = import_module(module_name)

# Check app exists and is a Dash app
app = getattr(module, "app", None)
assert app is not None, f"No 'app' object found in {app_path}"
assert isinstance(app, Dash), f"'app' object in {app_path} is not a Dash app"

# Use Flask's test client instead of running the server
with app.server.test_client() as client:
response = client.get("/")
assert response.status_code in (200, 302) # OK or redirect are both fine
Loading