From 7074c23fee75795d852cdc1f5aaf158bd147be4c Mon Sep 17 00:00:00 2001 From: Jimmy Shen Date: Mon, 20 Jan 2025 20:32:24 -0800 Subject: [PATCH 1/2] moved examp apps --- example_apps/relaxation_trajectory.py | 132 ++++++++++++++++++++++++++ tests/conftest.py | 7 ++ tests/test_example_apps.py | 30 ++++++ 3 files changed, 169 insertions(+) create mode 100644 example_apps/relaxation_trajectory.py create mode 100644 tests/test_example_apps.py diff --git a/example_apps/relaxation_trajectory.py b/example_apps/relaxation_trajectory.py new file mode 100644 index 00000000..c523940b --- /dev/null +++ b/example_apps/relaxation_trajectory.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import sys + +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from dash import Dash, dcc, html +from dash.dependencies import Input, Output +from pymatgen.core import Structure +from pymatgen.ext.matproj import MPRester + +import crystal_toolkit.components as ctc +from crystal_toolkit.settings import SETTINGS + +mp_id = "mp-1033715" +with MPRester(monty_decode=False) as mpr: + [task_doc] = mpr.tasks.search(task_ids=[mp_id]) + +steps = [ + ( + step.structure, + step.e_fr_energy, + np.linalg.norm(step.forces, axis=1).mean(), + ) + for calc in reversed(task_doc.calcs_reversed) + for step in calc.output.ionic_steps +] +assert len(steps) == 99 + +e_col = "Energy (eV)" +force_col = "Force (eV/Å)" +spg_col = "Spacegroup" +struct_col = "Structure" + +df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col]) +df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info) + + +def plot_energy_and_forces( + df: pd.DataFrame, + step: int, + e_col: str, + force_col: str, + title: str, +) -> go.Figure: + """Plot energy and forces as a function of relaxation step.""" + fig = go.Figure() + # energy trace = primary y-axis + fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy")) + + # forces trace = secondary y-axis + fig.add_trace( + go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2") + ) + + fig.update_layout( + template="plotly_white", + title=title, + xaxis={"title": "Relaxation Step"}, + yaxis={"title": e_col}, + yaxis2={"title": force_col, "overlaying": "y", "side": "right"}, + legend=dict(yanchor="top", y=1, xanchor="right", x=1), + ) + + # vertical line at the specified step + fig.add_vline(x=step, line={"dash": "dash", "width": 1}) + + return fig + + +if "struct_comp" not in locals(): + struct_comp = ctc.StructureMoleculeComponent( + id="structure", struct_or_mol=df_traj[struct_col][0] + ) + +step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps +slider = dcc.Slider( + id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag" +) + + +def make_title(spg: tuple[str, int]) -> str: + """Return a title for the figure.""" + href = f"https://materialsproject.org/materials/{mp_id}/" + return f"{mp_id} - {spg[0]} ({spg[1]})" + + +title = make_title(df_traj[spg_col][0]) +graph = dcc.Graph( + id="fig", + figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title), + style={"maxWidth": "50%"}, +) + +app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH) +app.layout = html.Div( + [ + html.H1( + "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em") + ), + html.P("Drag slider to see structure at different relaxation steps."), + slider, + html.Div( + [struct_comp.layout(), graph], + style=dict(display="flex", gap="2em", placeContent="center"), + ), + ], + style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"), +) + +ctc.register_crystal_toolkit(app=app, layout=app.layout) + + +@app.callback( + Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value") +) +def update_structure(step: int) -> tuple[Structure, go.Figure]: + """Update the structure displayed in the StructureMoleculeComponent and the + dashed vertical line in the figure when the slider is moved. + """ + title = make_title(df_traj[spg_col][step]) + fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title) + + return df_traj[struct_col][step], fig + + +# https://stackoverflow.com/a/74918941 +is_jupyter = "ipykernel" in sys.modules + +if __name__ == "__main__": + app.run(port=8050, debug=True, use_reloader=not is_jupyter) diff --git a/tests/conftest.py b/tests/conftest.py index 8f83768f..0f039ad7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,13 @@ def test_files(): """The path to the test_files directory.""" return Path(__file__).parent / "test_files" +@pytest.fixture(scope="session") +def example_apps(): + """Return each `app` object defined in the files in the example_apps directory.""" + examples_dir = Path(__file__).parent.parent / "example_apps" + files = examples_dir.glob("*.py") + return [file for file in files] + @pytest.fixture(scope="session") def standard_scenes(): diff --git a/tests/test_example_apps.py b/tests/test_example_apps.py new file mode 100644 index 00000000..b1d9c4f2 --- /dev/null +++ b/tests/test_example_apps.py @@ -0,0 +1,30 @@ +from pathlib import Path +import pytest +from importlib import import_module +from dash import Dash + + +@pytest.fixture(scope="session") +def example_apps(): + """Return paths to example app files.""" + examples_dir = Path(__file__).parent.parent / "example_apps" + return list(examples_dir.glob("*.py")) + + +def test_example_apps(example_apps): + """Check each app is a valid Dash instance and can handle a basic request.""" + for app_path in example_apps: + # Import the app module + relative_path = app_path.relative_to(app_path.parent.parent) + module_name = str(relative_path.with_suffix('')).replace('/', '.') + module = import_module(module_name) + + # Check app exists and is a Dash app + app = getattr(module, 'app', None) + assert app is not None, f"No 'app' object found in {app_path}" + assert isinstance(app, Dash), f"'app' object in {app_path} is not a Dash app" + + # Use Flask's test client instead of running the server + with app.server.test_client() as client: + response = client.get('/') + assert response.status_code in (200, 302) # OK or redirect are both fine \ No newline at end of file From 44d83c00541f36435a0d3116ec53aa92e049eedb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 04:34:31 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 1 + tests/test_example_apps.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0f039ad7..d92bf67d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ def test_files(): """The path to the test_files directory.""" return Path(__file__).parent / "test_files" + @pytest.fixture(scope="session") def example_apps(): """Return each `app` object defined in the files in the example_apps directory.""" diff --git a/tests/test_example_apps.py b/tests/test_example_apps.py index b1d9c4f2..9d9e2585 100644 --- a/tests/test_example_apps.py +++ b/tests/test_example_apps.py @@ -1,6 +1,7 @@ +from importlib import import_module from pathlib import Path + import pytest -from importlib import import_module from dash import Dash @@ -16,15 +17,15 @@ def test_example_apps(example_apps): for app_path in example_apps: # Import the app module relative_path = app_path.relative_to(app_path.parent.parent) - module_name = str(relative_path.with_suffix('')).replace('/', '.') + module_name = str(relative_path.with_suffix("")).replace("/", ".") module = import_module(module_name) - + # Check app exists and is a Dash app - app = getattr(module, 'app', None) + app = getattr(module, "app", None) assert app is not None, f"No 'app' object found in {app_path}" assert isinstance(app, Dash), f"'app' object in {app_path} is not a Dash app" # Use Flask's test client instead of running the server with app.server.test_client() as client: - response = client.get('/') - assert response.status_code in (200, 302) # OK or redirect are both fine \ No newline at end of file + response = client.get("/") + assert response.status_code in (200, 302) # OK or redirect are both fine