-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathchem_sys_treemap.py
166 lines (147 loc) · 5.51 KB
/
chem_sys_treemap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""Plot treemap distribution of chemical systems."""
# %%
from __future__ import annotations
import pandas as pd
import plotly.express as px
import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.utils import ROOT
pmv.set_plotly_template("plotly_dark")


# %% Basic example with different group_by options and customizations
formulas = (
    "Pb(Zr0.52Ti0.48)O3 La0.7Sr0.3MnO3 Li0.5Na0.5O LiNaO2 Li2O LiFeO2 "  # noqa: SIM905
    "LiFeO3 Al2O3 MgO".split()
)

# Render the same formulas once per grouping level to compare the treemaps.
for group_by in ("formula", "reduced_formula", "chem_sys"):
    fig = pmv.chem_sys_treemap(formulas, group_by=group_by, show_counts="value+percent")

    # Add customizations: rounded corners and custom hover info
    fig.update_traces(
        marker=dict(cornerradius=5),
        hovertemplate="<b>%{label}</b><br>Count: %{value}<br>Percentage: "
        "%{percentRoot:.1%} of total<extra></extra>",
    )
    title = f"Basic tree map grouped by {group_by} with rounded corners"
    fig.layout.title = dict(text=title, x=0.5, y=0.8, font_size=18)
    fig.show()

    # File-name suffix: "formula" -> "formula", "reduced_formula" ->
    # "reduced-formula", "chem_sys" -> "" (the base name already says chem-sys).
    # "chem_sys" must be stripped BEFORE underscores become dashes; in the
    # reverse order it is already "chem-sys" and the strip never matches.
    group_suffix = group_by.replace("chem_sys", "").replace("_", "-")
    # rstrip avoids a dangling "-" when the suffix is empty
    fig_name = f"chem-sys-treemap-{group_suffix}".rstrip("-")
    # pmv.io.save_and_compress_svg(fig, fig_name)
# %% Load the Ward metallic glass dataset https://pubs.acs.org/doi/10.1021/acs.chemmater.6b04153
csv_path = f"{ROOT}/examples/ward_metallic_glasses/ward-metallic-glasses.csv.xz"
# NOTE(review): na_values=() adds no extra NA markers; pandas' default NA strings
# still apply because keep_default_na defaults to True.
df_mg = pd.read_csv(csv_path, na_values=())
# Keep only rows whose "comment" column is empty
df_mg = df_mg.query("comment.isna()")

fig = pmv.chem_sys_treemap(
    df_mg[Key.composition],
    color_discrete_sequence=px.colors.qualitative.Set2,
    group_by="chem_sys",
    show_counts="value+percent",
)

# Customize tile text (label, count, percent of entry) and shade the root node
fig.update_traces(root_color="lightgrey", textinfo="label+value+percent entry")

title = "Ward Metallic Glass Dataset - With custom text display and root color"
fig.layout.title = dict(text=title, x=0.5, y=0.85, font_size=18)
fig.layout.height = 500
fig.show()
# %% Create a plot focusing on glass-forming ability (GFA) with patterns
# Fill pattern per arity name. This is loop-invariant, so it is built once.
patterns = {
    "unary": "|",
    "binary": "/",
    "ternary": "x",
    "quaternary": "+",
    "quinary": ".",
}
# Arity names are matched as substrings of each node's parent id. "ternary" is a
# substring of "quaternary", so scanning in dict order would give every
# quaternary system the ternary pattern. Check longer names first so the most
# specific arity wins.
arity_patterns = sorted(patterns.items(), key=lambda item: len(item[0]), reverse=True)

# One treemap per glass-forming-ability class
for key, df_sub in df_mg.groupby("gfa_type"):
    fig = pmv.chem_sys_treemap(
        df_sub[Key.composition],
        show_counts="value+percent",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    # Add customizations: patterns/textures and maximum depth
    fig.update_traces(
        maxdepth=2,  # Limit depth for clarity
        marker_pattern_shape=[
            next((shape for arity, shape in arity_patterns if arity in parent), "")
            for parent in fig.data[0].parents
        ],
    )
    title = f"Ward Metallic Glass Dataset - {key} Compositions<br>"
    title += "with patterns and limited depth"
    fig.layout.title = dict(text=title, x=0.5, y=0.8, font_size=18)
    fig.show()
    # pmv.io.save_and_compress_svg(fig, f"chem-sys-treemap-ward-bmg-{key.lower()}")
# %% Demonstrate the max_cells parameter with custom hover and rounded corners
top_n = 5  # chemical systems kept per arity; also drives the title and file name
fig = pmv.chem_sys_treemap(
    df_mg[Key.composition],
    group_by="chem_sys",
    show_counts="value+percent",
    max_cells=top_n,
    color_discrete_sequence=px.colors.qualitative.Set2,
)

# Hover text: bold label, raw count, and share of the whole dataset
hover_lines = (
    "<b>%{label}</b>",
    "Count: %{value}",
    "Percentage: %{percentRoot:.1%} of total<extra></extra>",
)
fig.update_traces(marker=dict(cornerradius=8), hovertemplate="<br>".join(hover_lines))

title = f"Ward Metallic Glass Dataset - Top {top_n} systems per arity with rounded corners"
fig.layout.title = dict(text=title, x=0.5, y=0.85, font_size=18)
fig.layout.height = 500
fig.show()
pmv.io.save_and_compress_svg(fig, f"chem-sys-treemap-top-{top_n}")
# %% Custom color mapping with additional customizations
# Small formula set chosen so that different group_by modes merge entries differently
formulas = [
    "Fe2O3",  # binary
    "Fe4O6",  # same as Fe2O3 when group_by="reduced_formula"
    "FeO",  # different formula but same system when group_by="chem_sys"
    "Li2O",  # binary
    "LiFeO2",  # ternary
    "Li3FeO3",  # ternary (same system as LiFeO2)
    "Al2O3",  # binary
    "MgO",  # binary
    "SiO2",  # binary
]
fig = pmv.chem_sys_treemap(formulas)

# Fixed color per chemical-system label
color_map = {
    "Fe-O": "red",
    "Li-O": "blue",
    "Fe-Li-O": "purple",
    "Al-O": "green",
    "Mg-O": "orange",
    "O-Si": "yellow",
}
# One entry per treemap node, in node order. Labels missing from color_map get
# None (presumably leaving those nodes to plotly's default coloring; confirm).
colors = list(map(color_map.get, fig.data[0].labels))

# Apply the colors plus rounded corners, richer tile text and a grey root
fig.update_traces(
    marker_colors=colors,
    marker_cornerradius=5,
    textinfo="label+value+percent entry",
    root_color="lightgrey",
)
title_text = "Treemap with Custom Color Mapping, Text Display and Root Color"
fig.layout.title = dict(text=title_text, x=0.5, y=0.85, font_size=18)
fig.show()
# %% Comprehensive example with multiple customizations
fig = pmv.chem_sys_treemap(formulas, color_discrete_sequence=px.colors.qualitative.Pastel)

# Fill pattern per arity, chosen by substring match against each node's parent id
patterns = {"unary": "|", "binary": "/", "ternary": "x"}
pattern_shapes = [
    next((shape for arity, shape in patterns.items() if arity in parent), "")
    for parent in fig.data[0].parents
]
hover_template = (
    "<b>%{label}</b>"
    "<br>Count: %{value}"
    "<br>Percentage: %{percentRoot:.1%} of total"
    "<extra></extra>"
)

# many customizations: patterns, hover info, rounded corners, and custom text
fig.update_traces(
    marker=dict(cornerradius=10),
    marker_pattern_shape=pattern_shapes,
    hovertemplate=hover_template,
    textinfo="label+value",
)
fig.layout.title = dict(
    text="Comprehensive Customization Example",
    x=0.5,
    y=0.85,
    font_size=18,
)
fig.show()