Skip to content

Commit 375c4f1

Browse files
committed
bug fix
1 parent bfaf7ae commit 375c4f1

File tree

1 file changed

+383
-0
lines changed

1 file changed

+383
-0
lines changed

plotting.py

+383
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
#!/usr/bin/python3
2+
3+
import sys
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
from matplotlib.collections import LineCollection
8+
from matplotlib.lines import Line2D
9+
from scipy.interpolate import griddata
10+
11+
import utilities
12+
from readvasp import from_poscar
13+
14+
15+
def plot_band_dos(KPT, eigdata, dosdata, config, axs=None, *args, **kwargs):
16+
ispin = eigdata.eig.shape[0]
17+
if axs is None:
18+
_, axs = plt.subplots(1,
19+
len(config["groupinfo"])*ispin + 1,
20+
sharey=True)
21+
for iax in axs[1:-1]:
22+
iax.sharex(axs[0])
23+
plot_band(KPT, eigdata, config, axs=axs[:-1], *args, **kwargs)
24+
plot_dos(dosdata, config, swap=True, ax=axs[-1], *args, **kwargs)
25+
[i.set_ylabel('') for i in axs[1:]]
26+
27+
28+
def plot_band(KPT, data, config, axs=None, *args, **kwargs):
29+
ispin = data.eig.shape[0]
30+
if not "f" in config["plotset"]["fermi"].lower():
31+
try:
32+
fermi = utilities.get_energy(data.eig,
33+
data.occ,
34+
info=config["plotset"]["fermi"])
35+
except:
36+
print("unknown type ,set fermi to 0eV")
37+
fermi = data.fermi
38+
config["plotset"]["fermi"] = "%s" % fermi
39+
eigall = data.eig - fermi
40+
else:
41+
eigall = data.eig
42+
used_points=[not "!" in i for i in KPT.k_sp_label]
43+
kpath=KPT.kpath[used_points]
44+
eigall=eigall[:,used_points]
45+
k_sp_label=[i for i in KPT.k_sp_label if not "!" in i]
46+
if axs is None:
47+
_, axs = plt.subplots(1,
48+
len(config["groupinfo"])*ispin,
49+
sharey=True, sharex=True)
50+
if type(axs).__name__ == "AxesSubplot":
51+
axs = np.array([axs])
52+
53+
for iaxs, group in zip(axs.reshape(-1, ispin), config["groupinfo"]):
54+
for i in range(ispin):
55+
ax = iaxs[i]
56+
e_range = np.where((eigall[i].max(axis=0) > config["plotset"]["Elim"][0]) & (
57+
eigall[i].min(axis=0) < config["plotset"]["Elim"][1]))[0]
58+
eig = eigall[:, :, e_range]
59+
if config["plotset"]["plot_type"] == 0:
60+
sc = plot_band_type_0(
61+
kpath, eig[i], ax=ax, c='black', *args, **kwargs)
62+
elif config["plotset"]["plot_type"] >= 1:
63+
pos = from_poscar.get_poscar(config["fileset"]["posfile"])
64+
tag, group_info = utilities.group_info_from_input(group)
65+
project_group = data.set_group(
66+
grouptag=tag, symbollist=pos.get_symbollist())
67+
project_group = project_group[:, :, :, e_range]
68+
if config["plotset"]["plot_type"] == 1:
69+
sc = plot_band_type_1(kpath, eig[i],
70+
project_group=project_group[i],
71+
group_info=group_info,
72+
ax=ax,
73+
int=config["plotset"]["int"],
74+
scale=config["plotset"]["scale"],
75+
*args,
76+
**kwargs)
77+
if i == (len(range(ispin))-1):
78+
legend_elements = [
79+
Line2D([0], [0], color=j, marker='o', label=i) for i, j in group_info]
80+
ax.legend(handles=legend_elements, **config["paras"]["band_legend_paras"])
81+
elif config["plotset"]["plot_type"] == 2:
82+
ax = iaxs[i]
83+
sc = plot_band_type_2(kpath,
84+
eig[i],
85+
project_group=project_group[i],
86+
group_info=group_info,
87+
ax=ax,
88+
*args,
89+
**kwargs)
90+
if i == range(ispin)[-1]:
91+
cb = plt.colorbar(sc, ax=ax)
92+
cb.set_ticks([0, 1])
93+
if len(group_info) == 1:
94+
cb.set_ticklabels([r'', r'' + group_info[0][0]])
95+
else:
96+
cb.set_ticklabels(
97+
[r'' + group_info[1][0], r'' + group_info[0][0]])
98+
else:
99+
print("unkown type for plotting, available value is 0,1,2")
100+
sys.exit()
101+
for ax in axs:
102+
plot_sp_kline(kpath, k_sp_label, ax=ax, *args, **kwargs)
103+
ax.set_ylim(config["plotset"]["Elim"])
104+
ax.set_xlim(kpath.min(), kpath.max())
105+
if config["plotset"]["fermi"]:
106+
ax.set_ylabel(r'$E\ -\ E_\mathrm{f}\ \mathrm{(eV)}$')
107+
else:
108+
ax.set_ylabel(r'$E\ \mathrm{eV}$')
109+
110+
111+
def plot_dos(data, config, swap=False, ax=None, *args, **kwargs):
112+
ispin = data.total.shape[0]
113+
dos = data.total
114+
type = config["plotset"]["plot_type"]
115+
if not "f" in config["plotset"]["fermi"].lower():
116+
if config["plotset"]["fermi"].lower() == "vbm":
117+
fermi = data.fermi
118+
elif "band" in config["plotset"]["fermi"].lower():
119+
print("DOS only supports setting fermi or a specify energy to 0eV")
120+
fermi = data.fermi
121+
else:
122+
fermi = utilities.get_energy(info=config["plotset"]["fermi"])
123+
if fermi is None:
124+
print("DOS only supports setting fermi or a specify energy to 0eV")
125+
fermi = data.fermi
126+
eig = data.eig - fermi
127+
print("Fermi is set to %s" % fermi)
128+
else:
129+
eig = data.eig
130+
e_range = np.where((config["plotset"]["Elim"][0] < eig)
131+
& (eig < config["plotset"]["Elim"][1]))[0]
132+
# print(0,e_range[0]-1)
133+
e_range = [max(0, e_range[0]-1)] + e_range.tolist() + \
134+
[min(e_range[-1]+1, eig.size - 1)]
135+
eig = eig[e_range]
136+
dos = dos[:, e_range]
137+
if ax is None:
138+
_, ax = plt.subplots()
139+
tag = []
140+
info = []
141+
if type > 0:
142+
for group in config["groupinfo"]:
143+
t_tag, t_info = utilities.group_info_from_input(group)
144+
tag.extend(t_tag)
145+
info.extend(t_info)
146+
147+
for i in range(ispin):
148+
sign = -np.sign(i - 0.5)
149+
if type == 0:
150+
plot_dos_type_0(eig, sign*dos[i], swap=swap, ax=ax, *args, **kwargs)
151+
elif type > 0:
152+
if swap:
153+
ax.fill_betweenx(
154+
eig, 0, sign*dos[i], color='gray', alpha=.25, label="Total")
155+
else:
156+
ax.fill_between(eig, sign*dos[i], color='gray',
157+
alpha=.25, label="Total")
158+
pos = from_poscar.get_poscar(config["fileset"]["posfile"])
159+
proj = data.set_group(
160+
grouptag=tag, symbollist=pos.get_symbollist())
161+
proj = proj[:, :, e_range]
162+
plot_dos_type_1(eig,
163+
dos[i],
164+
sign * proj[i],
165+
info,
166+
swap=swap,
167+
ax=ax,
168+
*args,
169+
**kwargs)
170+
else:
171+
print("unkwon type for dos")
172+
sys.exit()
173+
174+
if type > 0:
175+
xlabel = "PDOS"
176+
legend_elements = [Line2D([0], [0], color="gray", lw=2, label="Total")] + \
177+
[Line2D([0], [0], color=j, lw=2, label=i) for i, j in info]
178+
ax.legend(handles=legend_elements, **config["paras"]["dos_legend_paras"])
179+
doslim = 1.1 * np.nanmax(proj[:, :, ])
180+
else:
181+
xlabel = "DOS"
182+
doslim = 1.1 * np.nanmax(dos[:, ])
183+
if swap:
184+
ax.set_xlabel(xlabel)
185+
if config["plotset"]["fermi"]:
186+
ax.set_ylabel(r'$\mathrm{E\ -\ E}_f\ \mathrm{(eV)}$')
187+
else:
188+
ax.set_ylabel(r'$\mathrm{E}\ \mathrm{(eV)}$')
189+
ax.set_ylim(config["plotset"]["Elim"])
190+
ax.set_xlim([-doslim*i, doslim])
191+
else:
192+
ax.set_ylabel(xlabel)
193+
if config["plotset"]["fermi"]:
194+
ax.set_xlabel(r'$\mathrm{E\ -\ E}_f\ \mathrm{(eV)}$')
195+
else:
196+
ax.set_xlabel(r'$\mathrm{E}\ \mathrm{(eV)}$')
197+
ax.set_xlim(config["plotset"]["Elim"])
198+
ax.set_ylim([-doslim*i, doslim])
199+
200+
201+
def plot_sp_kline(kpath, k_sp_label, ax=None, *args, **kwargs):
202+
if ax is None:
203+
ax = plt.subplot()
204+
# "!" is marked for the scf part in HSE/metaGGA calculate and will be exclude in the plot
205+
kpath_sp = kpath[[bool(i) for i in k_sp_label]]
206+
ax.set_xticklabels([])
207+
if (len(kpath_sp) > 0):
208+
if (len(kpath) > 1):
209+
ax.vlines(kpath_sp, -1000, 1000, lw=0.5, color='black', *args, **kwargs,alpha=0.2)
210+
ax.set_xticks(kpath_sp)
211+
ax.set_xticklabels([i for i in k_sp_label if i])
212+
213+
214+
def plot_band_type_0(kpath, eig, ax=None, *args, **kwargs):
215+
if ax is None:
216+
ax = plt.subplot()
217+
if kpath.shape[0] == 1:
218+
if 'c' in kwargs.keys():
219+
kwargs['color'] = kwargs.pop('c')
220+
ax.hlines(eig, -0.5, 0.5, *args, **kwargs)
221+
else:
222+
ax.plot(kpath, eig, *args, **kwargs)
223+
224+
225+
def plot_band_type_1(kpath,
226+
eig,
227+
project_group,
228+
group_info,
229+
ax=None,
230+
int=100,
231+
scale=50,
232+
*args,
233+
**kwargs):
234+
if ax is None:
235+
ax = plt.subplot()
236+
project_group = np.cumsum(project_group[::-1], axis=0)[::-1]
237+
plot_band_type_0(kpath, eig, ax=ax, color='gray', *args, **kwargs)
238+
if kpath.shape[0] == 1:
239+
kpath = np.array([-0.5, 0.5])
240+
eig = eig.repeat(2, 0)
241+
project_group = project_group.repeat(2, 1)
242+
# print(eig.shape, project_group.shape)
243+
x = np.linspace(kpath.min(), kpath.max(), int)
244+
y = griddata(kpath, eig, x)
245+
for i, igroup in enumerate(project_group):
246+
label = r'' + group_info[i][0]
247+
Nband = eig.shape[1]
248+
Y = y
249+
X = x.repeat(Nband)
250+
s = griddata(kpath, igroup, x)
251+
ax.scatter(X,
252+
Y,
253+
c=group_info[i][1],
254+
s=s * scale,
255+
linewidths=None,
256+
label=label, zorder=-10)
257+
258+
259+
def plot_band_type_2(kpath,
260+
eig,
261+
project_group,
262+
group_info,
263+
ax=None,
264+
size=1,
265+
*args,
266+
**kwargs):
267+
if ax is None:
268+
ax = plt.subplot()
269+
norm = plt.Normalize(0, 1)
270+
for ieig in range(eig.shape[1]):
271+
x = kpath.repeat(2)[:-1]
272+
x[1::2] += (x[2::2] - x[1::2]) / 2
273+
y = eig[:, ieig].repeat(2)[:-1]
274+
y[1::2] += (y[2::2] - y[1::2]) / 2
275+
c = project_group[0, :, ieig].repeat(2)[:-1]
276+
c[1::2] += (c[2::2] - c[1::2]) / 2
277+
points = np.array([x, y]).T.reshape(-1, 1, 2)
278+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
279+
lc = LineCollection(segments, cmap="rainbow", norm=norm)
280+
lc.set_array(c)
281+
line = ax.add_collection(lc)
282+
return line
283+
284+
285+
def plot_dos_astype(type,
286+
eig,
287+
dos,
288+
project_group,
289+
group_info,
290+
swap=False,
291+
ax=None,
292+
*args,
293+
**kwargs):
294+
if ax is None:
295+
ax = plt.subplot()
296+
if type == 0:
297+
plot_dos_type_0(eig, dos, swap=swap, ax=ax, *args, **kwargs)
298+
elif type > 0:
299+
plot_dos_type_1(eig,
300+
dos,
301+
project_group,
302+
group_info,
303+
swap=swap,
304+
ax=ax,
305+
*args,
306+
**kwargs)
307+
else:
308+
print("unkwon type for dos")
309+
sys.exit()
310+
311+
312+
def plot_dos_type_0(eig, dos, swap=False, ax=None, *args, **kwargs):
313+
if ax is None:
314+
ax = plt.subplot()
315+
if swap:
316+
y, x = eig, dos
317+
else:
318+
x, y = eig, dos
319+
ax.plot(x, y, c='black', zorder=-10, *args, **kwargs)
320+
321+
322+
def plot_dos_type_1(eig,
323+
dos,
324+
project_group,
325+
group_info,
326+
swap=False,
327+
ax=None,
328+
*args,
329+
**kwargs):
330+
if ax is None:
331+
ax = plt.subplot()
332+
for i, igroup in enumerate(project_group):
333+
label = r'' + group_info[i][0]
334+
if swap:
335+
ax.plot(igroup, eig, c=group_info[i]
336+
[1], zorder=-10, *args, **kwargs)
337+
ax.fill_betweenx(eig,
338+
0,
339+
igroup,
340+
color=group_info[i][1],
341+
alpha=.25,
342+
label=label,
343+
*args,
344+
**kwargs)
345+
else:
346+
ax.plot(eig, igroup, c=group_info[i]
347+
[1], zorder=-10, *args, **kwargs)
348+
ax.fill_between(eig,
349+
igroup,
350+
color=group_info[i][1],
351+
alpha=.25,
352+
label=label,
353+
*args,
354+
**kwargs)
355+
356+
357+
def line_average_of_chg(chg, abc_axe, ax=None):
358+
"""
359+
average alone a axis
360+
_axe =
361+
'a','b','c' alone the lattice constant
362+
//'x','y','z' alone the x,y,z direction
363+
"""
364+
if ax is None:
365+
ax = plt.subplot()
366+
_axe = ['c', 'b', 'a'].index(abc_axe)
367+
# _axe=int(sys.argv[1])-1
368+
axe = [0, 1, 2]
369+
axe.remove(_axe)
370+
axe = tuple(axe)
371+
NG = chg.shape
372+
ax.plot(range(NG[_axe]) / NG[_axe] *
373+
np.linalg.norm(chg.cell[(2, 1, 0)[_axe], :]),
374+
np.average(chg.chg, axis=axe),
375+
color='black', zorder=-10)
376+
ax.set_xlim([
377+
np.min(
378+
range(NG[_axe]) / NG[_axe] *
379+
np.linalg.norm(chg.cell[(2, 1, 0)[_axe], :])),
380+
np.max(
381+
range(NG[_axe]) / NG[_axe] *
382+
np.linalg.norm(chg.cell[(2, 1, 0)[_axe], :]))
383+
])

0 commit comments

Comments
 (0)