Skip to content

Commit acec949

Browse files
committed
hook_up_fig_with_struct_viewer modify go.Figure annotations in place instead of creating copy to maintain pan and zoom state
- add several unit tests for the updated function
1 parent 858ac96 commit acec949

File tree

2 files changed

+276
-6
lines changed

2 files changed

+276
-6
lines changed

crystal_toolkit/helpers/utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,14 @@ def hook_up_fig_with_struct_viewer(
624624
Input(graph, "hoverData"),
625625
Input(graph, "clickData"),
626626
State(hover_click_dd, "value"),
627+
State(graph, "figure"),
627628
)
628629
def update_structure(
629630
hover_data: dict[str, list[dict[str, Any]]],
630631
click_data: dict[str, list[dict[str, Any]]], # needed only as callback trigger
631632
dropdown_value: str,
632-
) -> tuple[Structure, str, go.Figure] | tuple[None, None, None]:
633+
fig: dict[str, Any],
634+
) -> tuple[Structure, str, dict[str, Any]] | tuple[None, None, None]:
633635
"""Update StructureMoleculeComponent with pymatgen structure when user clicks or
634636
hovers a plot point.
635637
"""
@@ -651,13 +653,20 @@ def update_structure(
651653
struct_title = f"{material_id} ({struct.formula})"
652654

653655
if highlight_selected is not None:
654-
# remove existing annotations with name="selected"
655-
fig.layout.annotations = [
656-
anno for anno in fig.layout.annotations if anno.name != "selected"
656+
# Update annotations directly in the dictionary
657+
fig["layout"].setdefault("annotations", [])
658+
659+
# Remove existing annotations with name="selected"
660+
fig["layout"]["annotations"] = [
661+
anno
662+
for anno in fig["layout"]["annotations"]
663+
if anno.get("name") != "selected"
657664
]
658-
# highlight selected point in figure
665+
666+
# Add new annotation to highlight selected point
659667
anno = highlight_selected(hover_data["points"][0])
660-
fig.add_annotation(**anno, name="selected")
668+
anno["name"] = "selected"
669+
fig["layout"]["annotations"].append(anno)
661670

662671
return struct, struct_title, fig
663672

tests/test_utils.py

+261
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
from __future__ import annotations
2+
3+
import pandas as pd
4+
import plotly.express as px
5+
import plotly.graph_objects as go
6+
import pytest
7+
from dash import Dash, Output
8+
from pymatgen.core import Structure
9+
10+
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer
11+
12+
13+
@pytest.fixture()
14+
def sample_df() -> pd.DataFrame:
15+
"""Create sample data for testing."""
16+
# Create a simple structure
17+
from pymatgen.core import Lattice
18+
19+
struct = Structure(
20+
lattice=Lattice.cubic(3),
21+
species=("Fe", "Fe"),
22+
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
23+
)
24+
25+
# Create a DataFrame with some sample data
26+
return pd.DataFrame(
27+
{
28+
"material_id": ["mp-1", "mp-2"],
29+
"nsites": [2, 4],
30+
"volume": [10, 20],
31+
"structure": [struct, struct],
32+
}
33+
).set_index("material_id", drop=False)
34+
35+
36+
@pytest.fixture()
37+
def fig(sample_df: pd.DataFrame) -> go.Figure:
38+
# Create a simple scatter plot
39+
return px.scatter(
40+
sample_df, x="nsites", y="volume", hover_name=sample_df.index.name
41+
)
42+
43+
44+
def test_basic_functionality(fig: go.Figure, sample_df: pd.DataFrame):
45+
"""Test that the function creates a Dash app with the expected components."""
46+
app = hook_up_fig_with_struct_viewer(fig, sample_df)
47+
48+
# Check that the app was created
49+
assert isinstance(app, Dash)
50+
51+
# Check that the layout contains expected components
52+
layout = app.layout
53+
assert layout is not None
54+
assert "plot" in str(layout)
55+
assert "structure" in str(layout)
56+
assert "hover-click-dropdown" in str(layout)
57+
58+
59+
def test_callback_behavior(fig: go.Figure, sample_df: pd.DataFrame):
60+
"""Test that the callback updates the structure and annotations correctly."""
61+
app = hook_up_fig_with_struct_viewer(fig, sample_df)
62+
63+
# Create sample hover data
64+
hover_data = {"points": [{"x": 2, "y": 10, "hovertext": "mp-1"}]}
65+
66+
# Find the callback that has plot.figure as an output
67+
callback_key = None
68+
for key, value in app.callback_map.items():
69+
output = value.get("output", [])
70+
outputs = [output] if isinstance(output, Output) else output
71+
72+
if any(
73+
isinstance(output, Output)
74+
and output.component_id == "plot"
75+
and output.component_property == "figure"
76+
for output in outputs
77+
):
78+
callback_key = key
79+
break
80+
81+
assert callback_key.endswith("struct-title.children...plot.figure..")
82+
callback = app.callback_map[callback_key]["callback"]
83+
84+
# Get the input and state definitions
85+
inputs = app.callback_map[callback_key]["inputs"]
86+
states = app.callback_map[callback_key]["state"]
87+
88+
# Create the input arguments in the correct order
89+
args = []
90+
for input_def in inputs:
91+
if input_def["property"] == "hoverData":
92+
args.append(hover_data)
93+
elif input_def["property"] == "clickData":
94+
args.append(None)
95+
else:
96+
raise ValueError(f"Unexpected input property: {input_def['property']}")
97+
98+
# Add state arguments in the correct order
99+
for state_def in states:
100+
if state_def["property"] == "value":
101+
args.append("hover")
102+
elif state_def["property"] == "figure":
103+
args.append(fig.to_dict())
104+
else:
105+
raise ValueError(f"Unexpected state property: {state_def['property']}")
106+
107+
# Convert Output objects to dictionaries for outputs_list
108+
outputs = app.callback_map[callback_key]["output"]
109+
if isinstance(outputs, Output):
110+
outputs = [outputs]
111+
outputs_list = [
112+
{"id": output.component_id, "property": output.component_property}
113+
for output in outputs
114+
]
115+
116+
# Call the callback with the arguments in the correct order and outputs_list as a keyword argument
117+
result = callback(*args, outputs_list=outputs_list)
118+
119+
# Basic assertion that we got a result
120+
assert result.startswith('{"multi":true,"response"')
121+
122+
123+
def test_click_mode(fig: go.Figure, sample_df: pd.DataFrame):
124+
"""Test that the callback respects the click mode setting."""
125+
app = hook_up_fig_with_struct_viewer(fig, sample_df)
126+
127+
# Create sample hover data
128+
hover_data = {"points": [{"x": 2, "y": 10, "hovertext": "mp-1"}]}
129+
130+
# Find the callback that has plot.figure as an output
131+
callback_key = None
132+
for key, value in app.callback_map.items():
133+
output = value.get("output", [])
134+
outputs = [output] if isinstance(output, Output) else output
135+
136+
if any(
137+
isinstance(output, Output)
138+
and output.component_id == "plot"
139+
and output.component_property == "figure"
140+
for output in outputs
141+
):
142+
callback_key = key
143+
break
144+
145+
assert callback_key.endswith("struct-title.children...plot.figure..")
146+
callback = app.callback_map[callback_key]["callback"]
147+
148+
# Get the input and state definitions
149+
inputs = app.callback_map[callback_key]["inputs"]
150+
states = app.callback_map[callback_key]["state"]
151+
152+
# Create the input arguments in the correct order
153+
args = []
154+
for input_def in inputs:
155+
if input_def["property"] == "hoverData":
156+
args.append(hover_data)
157+
elif input_def["property"] == "clickData":
158+
args.append(None)
159+
else:
160+
raise ValueError(f"Unexpected input property: {input_def['property']}")
161+
162+
# Add state arguments in the correct order
163+
for state_def in states:
164+
if state_def["property"] == "value":
165+
args.append("click")
166+
elif state_def["property"] == "figure":
167+
args.append(fig.to_dict())
168+
else:
169+
raise ValueError(f"Unexpected state property: {state_def['property']}")
170+
171+
# Convert Output objects to dictionaries for outputs_list
172+
outputs = app.callback_map[callback_key]["output"]
173+
if isinstance(outputs, Output):
174+
outputs = [outputs]
175+
outputs_list = [
176+
{"id": output.component_id, "property": output.component_property}
177+
for output in outputs
178+
]
179+
180+
# Call the callback with the arguments in the correct order and outputs_list as a keyword argument
181+
result = callback(*args, outputs_list=outputs_list)
182+
183+
# Basic assertion that we got a result
184+
assert result.startswith('{"multi":true,"response"')
185+
186+
187+
def test_custom_highlight(fig: go.Figure, sample_df: pd.DataFrame):
188+
"""Test that custom highlighting function works."""
189+
190+
def custom_highlight(point):
191+
return {
192+
"x": point["x"],
193+
"y": point["y"],
194+
"xref": "x",
195+
"yref": "y",
196+
"text": f"Custom: {point['hovertext']}",
197+
"showarrow": True,
198+
}
199+
200+
app = hook_up_fig_with_struct_viewer(
201+
fig, sample_df, highlight_selected=custom_highlight
202+
)
203+
204+
# Create sample hover data
205+
hover_data = {"points": [{"x": 2, "y": 10, "hovertext": "mp-1"}]}
206+
207+
# Find the callback that has plot.figure as an output
208+
callback_key = None
209+
for key, value in app.callback_map.items():
210+
output = value.get("output", [])
211+
outputs = [output] if isinstance(output, Output) else output
212+
213+
if any(
214+
isinstance(output, Output)
215+
and output.component_id == "plot"
216+
and output.component_property == "figure"
217+
for output in outputs
218+
):
219+
callback_key = key
220+
break
221+
222+
assert callback_key.endswith("struct-title.children...plot.figure..")
223+
callback = app.callback_map[callback_key]["callback"]
224+
225+
# Get the input and state definitions
226+
inputs = app.callback_map[callback_key]["inputs"]
227+
states = app.callback_map[callback_key]["state"]
228+
229+
# Create the input arguments in the correct order
230+
args = []
231+
for input_def in inputs:
232+
if input_def["property"] == "hoverData":
233+
args.append(hover_data)
234+
elif input_def["property"] == "clickData":
235+
args.append(None)
236+
else:
237+
raise ValueError(f"Unexpected input property: {input_def['property']}")
238+
239+
# Add state arguments in the correct order
240+
for state_def in states:
241+
if state_def["property"] == "value":
242+
args.append("hover")
243+
elif state_def["property"] == "figure":
244+
args.append(fig.to_dict())
245+
else:
246+
raise ValueError(f"Unexpected state property: {state_def['property']}")
247+
248+
# Convert Output objects to dictionaries for outputs_list
249+
outputs = app.callback_map[callback_key]["output"]
250+
if isinstance(outputs, Output):
251+
outputs = [outputs]
252+
outputs_list = [
253+
{"id": output.component_id, "property": output.component_property}
254+
for output in outputs
255+
]
256+
257+
# Call the callback with the arguments in the correct order and outputs_list as a keyword argument
258+
result = callback(*args, outputs_list=outputs_list)
259+
260+
# Basic assertion that we got a result
261+
assert result.startswith('{"multi":true,"response"')

0 commit comments

Comments
 (0)