Skip to content

Commit 07283be

Browse files
committed
ipython tests
1 parent df33ac0 commit 07283be

File tree

3 files changed

+210
-0
lines changed

3 files changed

+210
-0
lines changed

tests/asy_test/core/__init__.py

Whitespace-only changes.

tests/asy_test/core/test_jupyter.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import plotly.express as px
2+
from monty.json import MSONable
3+
4+
from crystal_toolkit.core.jupyter import patch_msonable
5+
from crystal_toolkit.core.scene import Scene
6+
7+
8+
def test_patch_msonable():
9+
patch_msonable()
10+
11+
class GetSceneClass(MSONable):
12+
def get_scene(self):
13+
return Scene(name="test_scene")
14+
15+
class GetPlotClass(MSONable):
16+
def get_plot(self):
17+
"""Dummy plotly object"""
18+
return px.scatter(x=[1, 2, 3], y=[1, 2, 3])
19+
20+
class AsDictClass(MSONable):
21+
def __init__(self, a: int) -> None:
22+
self.a = a
23+
24+
# The output of _ipython_display_ is None
25+
# However, the logic for the creating the different output
26+
# dictionaries should be executed so the following tests
27+
# are still valuable.
28+
as_dict_class = AsDictClass(1)
29+
assert as_dict_class._ipython_display_() is None
30+
31+
get_scene_class = GetSceneClass()
32+
assert get_scene_class._ipython_display_() is None
33+
34+
get_plot_class = GetPlotClass()
35+
assert get_plot_class._ipython_display_() is None

tests/asy_test/core/test_legend.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from __future__ import annotations
2+
3+
from pymatgen.core import Lattice, Structure
4+
5+
from crystal_toolkit.core.legend import Legend
6+
7+
8+
class TestLegend:
9+
def setup_method(self, method):
10+
self.struct = Structure(
11+
Lattice.cubic(5),
12+
["H", "O", "In"],
13+
[[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0, 0]],
14+
site_properties={
15+
"example_site_prop": [5, 0, -3],
16+
"example_categorical_site_prop": ["4a", "4a", "8b"],
17+
},
18+
)
19+
20+
self.site0 = self.struct[0]
21+
self.sp0 = next(iter(self.site0.species))
22+
23+
self.site1 = self.struct[1]
24+
self.sp1 = next(iter(self.site1.species))
25+
26+
self.site2 = self.struct[2]
27+
self.sp2 = next(iter(self.site2.species))
28+
29+
self.struct_disordered = Structure(
30+
Lattice.cubic(5),
31+
["H", "O", {"In": 0.5, "Al": 0.5}],
32+
[[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0, 0]],
33+
site_properties={"example_site_prop": [5, 0, -3]},
34+
)
35+
36+
self.site_d = self.struct_disordered[2]
37+
self.site_d_sp0 = next(iter(self.site_d.species))
38+
self.site_d_sp1 = list(self.site_d.species)[1]
39+
40+
self.struct_manual = Structure(
41+
Lattice.cubic(5),
42+
["H", "O2-", "In"],
43+
[[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0, 0]],
44+
site_properties={"display_color": [[255, 0, 0], "blue", "#00ff00"]},
45+
)
46+
47+
def test_get_color(self):
48+
# test default
49+
50+
legend = Legend(self.struct, color_scheme="VESTA")
51+
52+
color = legend.get_color(self.sp0)
53+
assert color == "#ffcccc"
54+
55+
# element-based color schemes shouldn't change if you supply a site
56+
color = legend.get_color(self.sp0, site=self.site0)
57+
assert color == "#ffcccc"
58+
59+
color = legend.get_color(self.sp2)
60+
assert color == "#a67573"
61+
62+
assert legend.get_legend()["colors"] == {
63+
"#a67573": "In",
64+
"#fe0300": "O",
65+
"#ffcccc": "H",
66+
}
67+
68+
# test alternate
69+
70+
legend = Legend(self.struct, color_scheme="Jmol")
71+
72+
color = legend.get_color(self.sp0)
73+
assert color == "#ffffff"
74+
75+
assert legend.get_legend()["colors"] == {
76+
"#a67573": "In",
77+
"#ff0d0d": "O",
78+
"#ffffff": "H",
79+
}
80+
81+
# test coloring by site properties
82+
83+
legend = Legend(self.struct, color_scheme="example_site_prop")
84+
85+
color = legend.get_color(self.sp0, site=self.site0)
86+
assert color == "#b30326"
87+
88+
color = legend.get_color(self.sp1, site=self.site1)
89+
assert color == "#000000"
90+
91+
color = legend.get_color(self.sp2, site=self.site2)
92+
assert color == "#7b9ef8"
93+
94+
assert legend.get_legend()["colors"] == {
95+
"#7b9ef8": "-3.00",
96+
"#b30326": "5.00",
97+
"#000000": "0.00",
98+
}
99+
100+
# test accessible
101+
102+
legend = Legend(self.struct, color_scheme="accessible")
103+
104+
color = legend.get_color(self.sp0, site=self.site0)
105+
assert color == "#ffffff"
106+
107+
color = legend.get_color(self.sp1, site=self.site1)
108+
assert color == "#d55e00"
109+
110+
color = legend.get_color(self.sp2, site=self.site2)
111+
assert color == "#cc79a7"
112+
113+
assert legend.get_legend()["colors"] == {
114+
"#cc79a7": "In",
115+
"#d55e00": "O",
116+
"#ffffff": "H",
117+
}
118+
119+
# test disordered
120+
121+
legend = Legend(self.struct_disordered)
122+
123+
color = legend.get_color(self.site_d_sp0, site=self.site_d)
124+
assert color == "#a67573"
125+
126+
color = legend.get_color(self.site_d_sp1, site=self.site_d)
127+
assert color == "#bfa6a6"
128+
129+
assert legend.get_legend()["colors"] == {
130+
"#a67573": "In",
131+
"#bfa6a6": "Al",
132+
"#ff0d0d": "O",
133+
"#ffffff": "H",
134+
}
135+
136+
# test categorical
137+
138+
legend = Legend(self.struct, color_scheme="example_categorical_site_prop")
139+
140+
assert legend.get_legend()["colors"] == {"#377eb8": "8b", "#e41a1c": "4a"}
141+
142+
# test pre-defined
143+
144+
legend = Legend(self.struct_manual)
145+
146+
assert legend.get_legend()["colors"] == {
147+
"#0000ff": "O²⁻",
148+
"#00ff00": "In",
149+
"#ff0000": "H",
150+
}
151+
152+
def test_get_radius(self):
153+
legend = Legend(self.struct, radius_scheme="uniform")
154+
155+
assert legend.get_radius(sp=self.sp0) == 0.5
156+
157+
legend = Legend(self.struct, radius_scheme="covalent")
158+
159+
assert legend.get_radius(sp=self.sp1) == 0.66
160+
161+
legend = Legend(self.struct, radius_scheme="specified_or_average_ionic")
162+
163+
assert legend.get_radius(sp=self.sp2) == 0.94
164+
165+
def test_msonable(self):
166+
legend = Legend(self.struct)
167+
legend_dict = legend.as_dict()
168+
legend_from_dict = Legend.from_dict(legend_dict)
169+
170+
assert legend.get_legend() == legend_from_dict.get_legend()
171+
172+
def test_get_tiling(self):
173+
scene = self.struct.get_scene()
174+
assert hasattr(scene, "lattice")
175+
assert scene.lattice == [[5.0, 0, 0], [0, 5.0, 0], [0, 0, 5.0]]

0 commit comments

Comments
 (0)