@@ -37,8 +37,10 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
37
37
eigall = data .eig - fermi
38
38
else :
39
39
eigall = data .eig
40
-
41
-
40
+ used_points = [not "!" in i for i in KPT .k_sp_label ]
41
+ kpath = KPT .kpath [used_points ]
42
+ eigall = eigall [:,used_points ]
43
+ k_sp_label = [i for i in KPT .k_sp_label if not "!" in i ]
42
44
if axs is None :
43
45
_ , axs = plt .subplots (1 ,
44
46
len (config ["groupinfo" ])* ispin ,
@@ -48,42 +50,43 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
48
50
49
51
for iaxs , group in zip (axs .reshape (- 1 , ispin ), config ["groupinfo" ]):
50
52
for i in range (ispin ):
51
- ax = iaxs [i ]
52
- e_range = np .where ((eigall [i ].max (axis = 0 )> config ["plotset" ]["Elim" ][0 ])& (eigall [i ].min (axis = 0 )< config ["plotset" ]["Elim" ][1 ]))[0 ]
53
- eig = eigall [:,:,e_range ]
53
+ ax = iaxs [i ]
54
+ e_range = np .where ((eigall [i ].max (axis = 0 ) > config ["plotset" ]["Elim" ][0 ]) & (
55
+ eigall [i ].min (axis = 0 ) < config ["plotset" ]["Elim" ][1 ]))[0 ]
56
+ eig = eigall [:, :, e_range ]
54
57
if config ["plotset" ]["plot_type" ] == 0 :
55
58
sc = plot_band_type_0 (
56
- KPT . kpath , eig [i ], ax = ax , c = 'black' , * args , ** kwargs )
57
- elif config ["plotset" ]["plot_type" ]>= 1 :
59
+ kpath , eig [i ], ax = ax , c = 'black' , * args , ** kwargs )
60
+ elif config ["plotset" ]["plot_type" ] >= 1 :
58
61
pos = from_poscar .get_poscar (config ["fileset" ]["posfile" ])
59
62
tag , group_info = ultils .group_info_from_input (group )
60
63
project_group = data .set_group (
61
64
grouptag = tag , symbollist = pos .get_symbollist ())
62
- project_group = project_group [:,:,:, e_range ]
65
+ project_group = project_group [:, :, :, e_range ]
63
66
print (config ["plotset" ])
64
67
if config ["plotset" ]["plot_type" ] == 1 :
65
- sc = plot_band_type_1 (KPT . kpath , eig [i ],
66
- project_group = project_group [i ],
67
- group_info = group_info ,
68
- ax = ax ,
69
- int = config ["plotset" ]["int" ],
70
- scale = config ["plotset" ]["scale" ],
71
- * args ,
72
- ** kwargs )
73
- if i == (len (range (ispin ))- 1 ):
68
+ sc = plot_band_type_1 (kpath , eig [i ],
69
+ project_group = project_group [i ],
70
+ group_info = group_info ,
71
+ ax = ax ,
72
+ int = config ["plotset" ]["int" ],
73
+ scale = config ["plotset" ]["scale" ],
74
+ * args ,
75
+ ** kwargs )
76
+ if i == (len (range (ispin ))- 1 ):
74
77
legend_elements = [
75
78
Line2D ([0 ], [0 ], color = j , marker = 'o' , label = i ) for i , j in group_info ]
76
79
ax .legend (handles = legend_elements , loc = "upper right" )
77
80
elif config ["plotset" ]["plot_type" ] == 2 :
78
81
ax = iaxs [i ]
79
- sc = plot_band_type_2 (KPT . kpath ,
80
- eig [i ],
81
- project_group = project_group [i ],
82
- group_info = group_info ,
83
- ax = ax ,
84
- * args ,
85
- ** kwargs )
86
- if i == range (ispin )[- 1 ]:
82
+ sc = plot_band_type_2 (kpath ,
83
+ eig [i ],
84
+ project_group = project_group [i ],
85
+ group_info = group_info ,
86
+ ax = ax ,
87
+ * args ,
88
+ ** kwargs )
89
+ if i == range (ispin )[- 1 ]:
87
90
cb = plt .colorbar (sc , ax = axs )
88
91
cb .set_ticks ([0 , 1 ])
89
92
if len (group_info ) == 1 :
@@ -95,9 +98,9 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
95
98
print ("unkown type for plotting, available value is 0,1,2" )
96
99
sys .exit ()
97
100
for ax in axs :
98
- plot_sp_kline (KPT . kpath , KPT . k_sp_label , ax = ax , * args , ** kwargs )
101
+ plot_sp_kline (kpath , k_sp_label , ax = ax , * args , ** kwargs )
99
102
ax .set_ylim (config ["plotset" ]["Elim" ])
100
- ax .set_xlim (KPT . kpath .min (), KPT . kpath .max ())
103
+ ax .set_xlim (kpath .min (), kpath .max ())
101
104
if config ["plotset" ]["fermi" ]:
102
105
ax .set_ylabel (r'$E\ -\ E_f\ \mathrm{(eV)}$' )
103
106
else :
@@ -123,9 +126,10 @@ def plot_dos(data, config, swap=False, ax=None, *args, **kwargs):
123
126
e_range = np .where ((config ["plotset" ]["Elim" ][0 ] < eig )
124
127
& (eig < config ["plotset" ]["Elim" ][1 ]))[0 ]
125
128
# print(0,e_range[0]-1)
126
- e_range = [max (0 ,e_range [0 ]- 1 )] + e_range .tolist () + [min (e_range [- 1 ]+ 1 ,eig .size - 1 )]
127
- eig = eig [e_range ]
128
- dos = dos [:,e_range ]
129
+ e_range = [max (0 , e_range [0 ]- 1 )] + e_range .tolist () + \
130
+ [min (e_range [- 1 ]+ 1 , eig .size - 1 )]
131
+ eig = eig [e_range ]
132
+ dos = dos [:, e_range ]
129
133
if ax is None :
130
134
_ , ax = plt .subplots ()
131
135
tag = []
@@ -150,8 +154,8 @@ def plot_dos(data, config, swap=False, ax=None, *args, **kwargs):
150
154
pos = from_poscar .get_poscar (config ["fileset" ]["posfile" ])
151
155
proj = data .set_group (
152
156
grouptag = tag , symbollist = pos .get_symbollist ())
153
- proj = proj [:,:, e_range ]
154
- plot_dos_type_1 (eig ,
157
+ proj = proj [:, :, e_range ]
158
+ plot_dos_type_1 (eig ,
155
159
dos [i ],
156
160
sign * proj [i ],
157
161
info ,
@@ -194,15 +198,11 @@ def plot_sp_kline(kpath, k_sp_label, ax=None, *args, **kwargs):
194
198
if ax is None :
195
199
ax = plt .subplot ()
196
200
# "!" is marked for the scf part in HSE/metaGGA calculate and will be exclude in the plot
197
- kpath_sp = kpath [[i for i , j in enumerate ( k_sp_label ) if j ]]
201
+ kpath_sp = kpath [[bool ( i ) for i in k_sp_label ]]
198
202
ax .set_xticklabels ([])
199
203
if (len (kpath_sp ) > 0 ):
200
204
if (len (kpath ) > 1 ):
201
- ax .vlines (kpath_sp ,
202
- - 1000 , 1000 ,
203
- color = 'black' ,
204
- * args ,
205
- ** kwargs )
205
+ ax .vlines (kpath_sp , - 1000 , 1000 , color = 'black' , * args , ** kwargs )
206
206
ax .set_xticks (kpath_sp )
207
207
ax .set_xticklabels ([i for i in k_sp_label if i ])
208
208
@@ -212,8 +212,8 @@ def plot_band_type_0(kpath, eig, ax=None, *args, **kwargs):
212
212
ax = plt .subplot ()
213
213
if kpath .shape [0 ] == 1 :
214
214
if 'c' in kwargs .keys ():
215
- kwargs ['color' ]= kwargs .pop ('c' )
216
- ax .hlines (eig ,- 0.5 ,0.5 , * args , ** kwargs )
215
+ kwargs ['color' ] = kwargs .pop ('c' )
216
+ ax .hlines (eig , - 0.5 , 0.5 , * args , ** kwargs )
217
217
else :
218
218
ax .plot (kpath , eig , * args , ** kwargs )
219
219
@@ -235,10 +235,10 @@ def plot_band_type_1(kpath,
235
235
# for i, igroup in enumerate(project_group):
236
236
# label = r'' + group_info[i][0]
237
237
# ax.hlines(eig,-0.5,0.5,color=group_info[i][1],lw=project_group,label=label)
238
- kpath = np .array ([- 0.5 ,0.5 ])
239
- eig = eig .repeat (2 ,0 )
240
- project_group = project_group .repeat (2 ,1 )
241
- print (eig .shape ,project_group .shape )
238
+ kpath = np .array ([- 0.5 , 0.5 ])
239
+ eig = eig .repeat (2 , 0 )
240
+ project_group = project_group .repeat (2 , 1 )
241
+ print (eig .shape , project_group .shape )
242
242
else :
243
243
plot_band_type_0 (kpath , eig , ax = ax , c = 'gray' , * args , ** kwargs )
244
244
x = np .linspace (kpath .min (), kpath .max (), int )
@@ -250,11 +250,11 @@ def plot_band_type_1(kpath,
250
250
X = x .repeat (Nband )
251
251
s = griddata (kpath , igroup , x )
252
252
ax .scatter (X ,
253
- Y ,
254
- c = group_info [i ][1 ],
255
- s = s * scale ,
256
- linewidths = None ,
257
- label = label , zorder = - 10 )
253
+ Y ,
254
+ c = group_info [i ][1 ],
255
+ s = s * scale ,
256
+ linewidths = None ,
257
+ label = label , zorder = - 10 )
258
258
259
259
260
260
def plot_band_type_2 (kpath ,
@@ -280,18 +280,6 @@ def plot_band_type_2(kpath,
280
280
lc = LineCollection (segments , cmap = "rainbow" , norm = norm )
281
281
lc .set_array (c )
282
282
line = ax .add_collection (lc )
283
- # plt.plot(kpath[i:i+2],eig[i:i+2,j],c=col[C[i,j]])
284
- # sc = ax.scatter(X,
285
- # Y,
286
- # c=C,
287
- # vmin=0,
288
- # vmax=1,
289
- # s=size,
290
- # linewidths=None,
291
- # marker='o',
292
- # cmap="hsv",
293
- # *args,
294
- # **kwargs)
295
283
return line
296
284
297
285
0 commit comments