Skip to content

Commit 2b949b7

Browse files
authored
Merge pull request #430 from materialsproject/hook_up_fig_with_struct_viewer-transform_id
`hook_up_fig_with_struct_viewer` rename kwarg `validate_id` to `transform_id`
2 parents 05b691b + 8c862e5 commit 2b949b7

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

crystal_toolkit/helpers/utils.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import urllib.parse
66
from fractions import Fraction
7-
from typing import TYPE_CHECKING, Any, Callable
7+
from typing import TYPE_CHECKING, Any, Callable, Literal
88
from uuid import uuid4
99

1010
import dash
@@ -519,7 +519,7 @@ def hook_up_fig_with_struct_viewer(
519519
fig: go.Figure,
520520
df: pd.DataFrame,
521521
struct_col: str = "structure",
522-
validate_id: Callable[[str], bool] = lambda id: True,
522+
transform_id: Callable[[str], str | Literal[False]] = lambda mat_id: mat_id,
523523
highlight_selected: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
524524
) -> Dash:
525525
"""Create a Dash app that hooks up a Plotly figure with a Crystal Toolkit structure
@@ -555,11 +555,10 @@ def hook_up_fig_with_struct_viewer(
555555
struct_col (str, optional): Name of the column in the data frame that contains
556556
the structures. Defaults to 'structure'. Can be instances of
557557
pymatgen.core.Structure or dicts created with Structure.as_dict().
558-
validate_id (Callable[[str], bool], optional): Function that takes a string
558+
transform_id (Callable[[str], str | False], optional): Function that takes a string
559559
extracted from the hovertext key of a hoverData event payload and returns
560-
True if the string is a valid df row index. Defaults to lambda
561-
id: True. Useful for not running the update-structure
562-
callback on unexpected data.
560+
a string that can be used to index the dataframe. Return False to prevent
561+
the update-structure callback from being called.
563562
highlight_selected (Callable[[dict[str, Any]], dict[str, Any]], optional):
564563
Function that takes the clicked or last-hovered point and returns a dict of
565564
kwargs to be passed to go.Figure.add_annotation() to highlight said point.
@@ -642,9 +641,8 @@ def update_structure(
642641

643642
# hover_data and click_data are identical since a hover event always precedes a
644643
# click so we always use hover_data
645-
material_id = hover_data["points"][0]["hovertext"]
646-
if not validate_id(material_id):
647-
print(f"bad {material_id=}")
644+
material_id = transform_id(hover_data["points"][0]["hovertext"])
645+
if material_id is False:
648646
raise dash.exceptions.PreventUpdate
649647

650648
struct = df[struct_col][material_id]

0 commit comments

Comments
 (0)