Skip to content

Commit 0d3e20b

Browse files
committed
future annotations
1 parent f8045b0 commit 0d3e20b

File tree

13 files changed

+187
-181
lines changed

13 files changed

+187
-181
lines changed

crystal_toolkit/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from __future__ import annotations
2+
13
import json
24
import os as _os
35
from collections import defaultdict
46
from pathlib import Path
7+
from typing import Any
58

69
# pleasant hack to support MSONable objects in Dash callbacks natively
710
from monty.json import MSONable
@@ -30,7 +33,7 @@ def to_plotly_json(self):
3033

3134

3235
# Populate the default values from the JSON file
33-
_DEFAULTS = defaultdict()
36+
_DEFAULTS: dict[str, Any] = defaultdict()
3437
default_js = _os.path.join(
3538
_os.path.join(_os.path.dirname(_os.path.abspath(__file__))), "./", "defaults.json"
3639
)

crystal_toolkit/apps/main.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from __future__ import annotations
2+
13
import logging
24
import os
35
import warnings
46
from random import choice
57
from time import time
6-
from typing import Optional
78
from urllib import parse
89
from uuid import uuid4
910

@@ -448,7 +449,7 @@ def update_search_term_on_page_load(href: str) -> str:
448449
[Input(search_component.id("input"), "value")],
449450
[State(search_component.id("input"), "n_submit")],
450451
)
451-
def perform_search_on_page_load(search_term: str, n_submit: Optional[int]):
452+
def perform_search_on_page_load(search_term: str, n_submit: int | None):
452453
"""
453454
Loading with an mpid in the URL requires populating the search term with
454455
the mpid, this callback forces that search to then take place by force updating
@@ -468,7 +469,7 @@ def perform_search_on_page_load(search_term: str, n_submit: Optional[int]):
468469

469470

470471
@app.callback(Output("url", "pathname"), [Input(search_component.id(), "data")])
471-
def update_url_pathname_from_search_term(mpid: Optional[str]) -> str:
472+
def update_url_pathname_from_search_term(mpid: str | None) -> str:
472473
"""
473474
Updates the URL from the search term. Technically a circular callback,
474475
this is done to prevent the app seeming inconsistent from the end user.
@@ -488,7 +489,7 @@ def update_url_pathname_from_search_term(mpid: Optional[str]) -> str:
488489
Output(transformation_component.id("input_structure"), "data"),
489490
[Input(search_component.id(), "data"), Input(upload_component.id(), "data")],
490491
)
491-
def master_update_structure(search_mpid: Optional[str], upload_data: Optional[str]):
492+
def master_update_structure(search_mpid: str | None, upload_data: dict | None):
492493
"""
493494
A new structure is loaded either from the search component or from the
494495
upload component. This callback triggers the update, and uses the callback

crystal_toolkit/components/bandstructure.py

-10
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,9 @@
2626
Loading,
2727
MessageBody,
2828
MessageContainer,
29-
bandstructure_symm_line,
30-
bs,
31-
bsml,
3229
dcc,
33-
density_of_states,
34-
dos,
35-
dos_select,
36-
elements,
37-
get_bandstructure_traces,
3830
get_data_list,
39-
get_dos_traces,
4031
html,
41-
path_convention,
4232
)
4333

4434
# Author: Jason Munro

crystal_toolkit/components/structure.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from __future__ import annotations
2+
13
import re
24
import warnings
35
from base64 import b64encode
46
from collections import OrderedDict
57
from itertools import chain, combinations_with_replacement
68
from pathlib import Path
79
from tempfile import TemporaryDirectory
8-
from typing import Dict, Literal, Optional, Tuple, Union
10+
from typing import Literal
911

1012
import numpy as np
1113
from dash import dash_table as dt
@@ -29,7 +31,7 @@
2931

3032
# TODO: make dangling bonds "stubs"? (fixed length)
3133

32-
DEFAULTS = {
34+
DEFAULTS: dict[str, str | bool] = {
3335
"color_scheme": "VESTA",
3436
"bonding_strategy": "CrystalNN",
3537
"radius_strategy": "uniform",
@@ -85,16 +87,15 @@ class StructureMoleculeComponent(MPComponent):
8587

8688
def __init__(
8789
self,
88-
struct_or_mol: Optional[
89-
Union[Structure, StructureGraph, Molecule, MoleculeGraph]
90-
] = None,
90+
struct_or_mol: None
91+
| (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
9192
id: str = None,
9293
className: str = "box",
93-
scene_additions: Optional[Scene] = None,
94+
scene_additions: Scene | None = None,
9495
bonding_strategy: str = DEFAULTS["bonding_strategy"],
95-
bonding_strategy_kwargs: Optional[dict] = None,
96+
bonding_strategy_kwargs: dict | None = None,
9697
color_scheme: str = DEFAULTS["color_scheme"],
97-
color_scale: Optional[str] = None,
98+
color_scale: str | None = None,
9899
radius_strategy: str = DEFAULTS["radius_strategy"],
99100
unit_cell_choice: str = DEFAULTS["unit_cell_choice"],
100101
draw_image_atoms: bool = DEFAULTS["draw_image_atoms"],
@@ -103,8 +104,8 @@ def __init__(
103104
],
104105
hide_incomplete_bonds: bool = DEFAULTS["hide_incomplete_bonds"],
105106
show_compass: bool = DEFAULTS["show_compass"],
106-
scene_settings: Optional[Dict] = None,
107-
group_by_site_property: Optional[str] = None,
107+
scene_settings: dict | None = None,
108+
group_by_site_property: str | None = None,
108109
show_legend: bool = DEFAULTS["show_legend"],
109110
show_settings: bool = DEFAULTS["show_settings"],
110111
show_controls: bool = DEFAULTS["show_controls"],
@@ -907,7 +908,7 @@ def layout(self, size: str = "500px") -> html.Div:
907908

908909
@staticmethod
909910
def _preprocess_structure(
910-
struct_or_mol: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
911+
struct_or_mol: Structure | StructureGraph | Molecule | MoleculeGraph,
911912
unit_cell_choice: Literal[
912913
"input", "primitive", "conventional", "reduced_niggli", "reduced_lll"
913914
] = "input",
@@ -931,10 +932,10 @@ def _preprocess_structure(
931932

932933
@staticmethod
933934
def _preprocess_input_to_graph(
934-
input: Union[Structure, StructureGraph, Molecule, MoleculeGraph],
935+
input: Structure | StructureGraph | Molecule | MoleculeGraph,
935936
bonding_strategy: str = DEFAULTS["bonding_strategy"],
936-
bonding_strategy_kwargs: Optional[Dict] = None,
937-
) -> Union[StructureGraph, MoleculeGraph]:
937+
bonding_strategy_kwargs: dict | None = None,
938+
) -> StructureGraph | MoleculeGraph:
938939

939940
if isinstance(input, Structure):
940941

@@ -1007,8 +1008,8 @@ def _preprocess_input_to_graph(
10071008

10081009
@staticmethod
10091010
def _get_struct_or_mol(
1010-
graph: Union[StructureGraph, MoleculeGraph, Structure, Molecule]
1011-
) -> Union[Structure, Molecule]:
1011+
graph: StructureGraph | MoleculeGraph | Structure | Molecule,
1012+
) -> Structure | Molecule:
10121013
if isinstance(graph, StructureGraph):
10131014
return graph.structure
10141015
elif isinstance(graph, MoleculeGraph):
@@ -1020,7 +1021,7 @@ def _get_struct_or_mol(
10201021

10211022
@staticmethod
10221023
def get_scene_and_legend(
1023-
graph: Optional[Union[StructureGraph, MoleculeGraph]],
1024+
graph: StructureGraph | MoleculeGraph | None,
10241025
color_scheme=DEFAULTS["color_scheme"],
10251026
color_scale=None,
10261027
radius_strategy=DEFAULTS["radius_strategy"],
@@ -1031,7 +1032,7 @@ def get_scene_and_legend(
10311032
scene_additions=None,
10321033
show_compass=DEFAULTS["show_compass"],
10331034
group_by_site_property=None,
1034-
) -> Tuple[Scene, Dict[str, str]]:
1035+
) -> tuple[Scene, dict[str, str]]:
10351036

10361037
scene = Scene(name="StructureMoleculeComponentScene")
10371038

crystal_toolkit/components/transformations/core.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from __future__ import annotations
2+
13
import traceback
24
import warnings
3-
from typing import List, Optional
45

56
import dash
67
import dash_daq as daq
@@ -133,7 +134,7 @@ def container_layout(self, state=None, structure=None) -> html.Div:
133134

134135
return container
135136

136-
def options_layouts(self, state=None, structure=None) -> List[html.Div]:
137+
def options_layouts(self, state=None, structure=None) -> list[html.Div]:
137138
"""
138139
Return a layout to change the transformation options (that is,
139140
that controls the args and kwargs that will be passed to pymatgen).
@@ -270,8 +271,8 @@ def update_transformation(enabled, states):
270271
class AllTransformationsComponent(MPComponent):
271272
def __init__(
272273
self,
273-
transformations: Optional[List[str]] = None,
274-
input_structure_component: Optional[MPComponent] = None,
274+
transformations: list[str] | None = None,
275+
input_structure_component: MPComponent | None = None,
275276
*args,
276277
**kwargs,
277278
):

crystal_toolkit/core/legend.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from __future__ import annotations
2+
13
import os
24
import warnings
35
from collections import defaultdict
46
from itertools import chain
5-
from typing import Any, Dict, List, Optional, Tuple, Union
7+
from typing import Any
68

79
import numpy as np
810
from matplotlib.cm import get_cmap
@@ -41,11 +43,11 @@ class Legend(MSONable):
4143

4244
def __init__(
4345
self,
44-
site_collection: Union[SiteCollection, Site],
46+
site_collection: SiteCollection | Site,
4547
color_scheme: str = "Jmol",
4648
radius_scheme: str = "uniform",
4749
cmap: str = "coolwarm",
48-
cmap_range: Optional[Tuple[float, float]] = None,
50+
cmap_range: tuple[float, float] | None = None,
4951
):
5052
"""
5153
Create a legend for a given SiteCollection to choose how to
@@ -137,7 +139,7 @@ def __init__(
137139
@staticmethod
138140
def generate_accessible_color_scheme_on_the_fly(
139141
site_collection: SiteCollection,
140-
) -> Dict[str, Dict[str, Tuple[int, int, int]]]:
142+
) -> dict[str, dict[str, tuple[int, int, int]]]:
141143
"""
142144
e.g. for a color scheme more appropriate for people with color blindness
143145
@@ -217,7 +219,7 @@ def generate_accessible_color_scheme_on_the_fly(
217219
@staticmethod
218220
def generate_categorical_color_scheme_on_the_fly(
219221
site_collection: SiteCollection, site_prop_types
220-
) -> Dict[str, Dict[str, Tuple[int, int, int]]]:
222+
) -> dict[str, dict[str, tuple[int, int, int]]]:
221223
"""
222224
e.g. for Wykcoff
223225
@@ -257,7 +259,7 @@ def generate_categorical_color_scheme_on_the_fly(
257259

258260
return color_scheme
259261

260-
def get_color(self, sp: Union[Specie, Element], site: Optional[Site] = None) -> str:
262+
def get_color(self, sp: Specie | Element, site: Site | None = None) -> str:
261263
"""
262264
Get a color to render a specific species. Optionally, you can provide
263265
a site for context, since ...
@@ -335,9 +337,7 @@ def get_color(self, sp: Union[Specie, Element], site: Optional[Site] = None) ->
335337

336338
return html5_serialize_simple_color(color)
337339

338-
def get_radius(
339-
self, sp: Union[Specie, Element], site: Optional[Site] = None
340-
) -> float:
340+
def get_radius(self, sp: Specie | Element, site: Site | None = None) -> float:
341341

342342
# allow manual override by user
343343
if site and "display_radius" in site.properties:
@@ -380,7 +380,7 @@ def get_radius(
380380
return radius
381381

382382
@staticmethod
383-
def analyze_site_props(site_collection: SiteCollection) -> Dict[str, List[str]]:
383+
def analyze_site_props(site_collection: SiteCollection) -> dict[str, list[str]]:
384384
"""
385385
Returns: A dictionary with keys "scalar", "matrix", "vector", "categorical"
386386
and values of a list of site property names corresponding to each type
@@ -400,7 +400,7 @@ def analyze_site_props(site_collection: SiteCollection) -> Dict[str, List[str]]:
400400
return dict(site_prop_names)
401401

402402
@staticmethod
403-
def get_species_str(sp: Union[Specie, Element]) -> str:
403+
def get_species_str(sp: Specie | Element) -> str:
404404
"""
405405
Args:
406406
sp: Specie or Element
@@ -411,7 +411,7 @@ def get_species_str(sp: Union[Specie, Element]) -> str:
411411
# and then move this to pymatgen string utils ...
412412
return unicodeify_species(str(sp))
413413

414-
def get_legend(self) -> Dict[str, Any]:
414+
def get_legend(self) -> dict[str, Any]:
415415

416416
# decide what we want the labels to be
417417
if self.color_scheme in ("Jmol", "VESTA", "accessible"):

0 commit comments

Comments
 (0)