@@ -37,8 +37,10 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
3737 eigall = data .eig - fermi
3838 else :
3939 eigall = data .eig
40-
41-
40+ used_points = [not "!" in i for i in KPT .k_sp_label ]
41+ kpath = KPT .kpath [used_points ]
42+ eigall = eigall [:,used_points ]
43+ k_sp_label = [i for i in KPT .k_sp_label if not "!" in i ]
4244 if axs is None :
4345 _ , axs = plt .subplots (1 ,
4446 len (config ["groupinfo" ])* ispin ,
@@ -48,42 +50,43 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
4850
4951 for iaxs , group in zip (axs .reshape (- 1 , ispin ), config ["groupinfo" ]):
5052 for i in range (ispin ):
51- ax = iaxs [i ]
52- e_range = np .where ((eigall [i ].max (axis = 0 )> config ["plotset" ]["Elim" ][0 ])& (eigall [i ].min (axis = 0 )< config ["plotset" ]["Elim" ][1 ]))[0 ]
53- eig = eigall [:,:,e_range ]
53+ ax = iaxs [i ]
54+ e_range = np .where ((eigall [i ].max (axis = 0 ) > config ["plotset" ]["Elim" ][0 ]) & (
55+ eigall [i ].min (axis = 0 ) < config ["plotset" ]["Elim" ][1 ]))[0 ]
56+ eig = eigall [:, :, e_range ]
5457 if config ["plotset" ]["plot_type" ] == 0 :
5558 sc = plot_band_type_0 (
56- KPT . kpath , eig [i ], ax = ax , c = 'black' , * args , ** kwargs )
57- elif config ["plotset" ]["plot_type" ]>= 1 :
59+ kpath , eig [i ], ax = ax , c = 'black' , * args , ** kwargs )
60+ elif config ["plotset" ]["plot_type" ] >= 1 :
5861 pos = from_poscar .get_poscar (config ["fileset" ]["posfile" ])
5962 tag , group_info = ultils .group_info_from_input (group )
6063 project_group = data .set_group (
6164 grouptag = tag , symbollist = pos .get_symbollist ())
62- project_group = project_group [:,:,:, e_range ]
65+ project_group = project_group [:, :, :, e_range ]
6366 print (config ["plotset" ])
6467 if config ["plotset" ]["plot_type" ] == 1 :
65- sc = plot_band_type_1 (KPT . kpath , eig [i ],
66- project_group = project_group [i ],
67- group_info = group_info ,
68- ax = ax ,
69- int = config ["plotset" ]["int" ],
70- scale = config ["plotset" ]["scale" ],
71- * args ,
72- ** kwargs )
73- if i == (len (range (ispin ))- 1 ):
68+ sc = plot_band_type_1 (kpath , eig [i ],
69+ project_group = project_group [i ],
70+ group_info = group_info ,
71+ ax = ax ,
72+ int = config ["plotset" ]["int" ],
73+ scale = config ["plotset" ]["scale" ],
74+ * args ,
75+ ** kwargs )
76+ if i == (len (range (ispin ))- 1 ):
7477 legend_elements = [
7578 Line2D ([0 ], [0 ], color = j , marker = 'o' , label = i ) for i , j in group_info ]
7679 ax .legend (handles = legend_elements , loc = "upper right" )
7780 elif config ["plotset" ]["plot_type" ] == 2 :
7881 ax = iaxs [i ]
79- sc = plot_band_type_2 (KPT . kpath ,
80- eig [i ],
81- project_group = project_group [i ],
82- group_info = group_info ,
83- ax = ax ,
84- * args ,
85- ** kwargs )
86- if i == range (ispin )[- 1 ]:
82+ sc = plot_band_type_2 (kpath ,
83+ eig [i ],
84+ project_group = project_group [i ],
85+ group_info = group_info ,
86+ ax = ax ,
87+ * args ,
88+ ** kwargs )
89+ if i == range (ispin )[- 1 ]:
8790 cb = plt .colorbar (sc , ax = axs )
8891 cb .set_ticks ([0 , 1 ])
8992 if len (group_info ) == 1 :
@@ -95,9 +98,9 @@ def plot_band(KPT, data, config, axs=None, *args, **kwargs):
9598 print ("unkown type for plotting, available value is 0,1,2" )
9699 sys .exit ()
97100 for ax in axs :
98- plot_sp_kline (KPT . kpath , KPT . k_sp_label , ax = ax , * args , ** kwargs )
101+ plot_sp_kline (kpath , k_sp_label , ax = ax , * args , ** kwargs )
99102 ax .set_ylim (config ["plotset" ]["Elim" ])
100- ax .set_xlim (KPT . kpath .min (), KPT . kpath .max ())
103+ ax .set_xlim (kpath .min (), kpath .max ())
101104 if config ["plotset" ]["fermi" ]:
102105 ax .set_ylabel (r'$E\ -\ E_f\ \mathrm{(eV)}$' )
103106 else :
@@ -123,9 +126,10 @@ def plot_dos(data, config, swap=False, ax=None, *args, **kwargs):
123126 e_range = np .where ((config ["plotset" ]["Elim" ][0 ] < eig )
124127 & (eig < config ["plotset" ]["Elim" ][1 ]))[0 ]
125128 # print(0,e_range[0]-1)
126- e_range = [max (0 ,e_range [0 ]- 1 )] + e_range .tolist () + [min (e_range [- 1 ]+ 1 ,eig .size - 1 )]
127- eig = eig [e_range ]
128- dos = dos [:,e_range ]
129+ e_range = [max (0 , e_range [0 ]- 1 )] + e_range .tolist () + \
130+ [min (e_range [- 1 ]+ 1 , eig .size - 1 )]
131+ eig = eig [e_range ]
132+ dos = dos [:, e_range ]
129133 if ax is None :
130134 _ , ax = plt .subplots ()
131135 tag = []
@@ -150,8 +154,8 @@ def plot_dos(data, config, swap=False, ax=None, *args, **kwargs):
150154 pos = from_poscar .get_poscar (config ["fileset" ]["posfile" ])
151155 proj = data .set_group (
152156 grouptag = tag , symbollist = pos .get_symbollist ())
153- proj = proj [:,:, e_range ]
154- plot_dos_type_1 (eig ,
157+ proj = proj [:, :, e_range ]
158+ plot_dos_type_1 (eig ,
155159 dos [i ],
156160 sign * proj [i ],
157161 info ,
@@ -194,15 +198,11 @@ def plot_sp_kline(kpath, k_sp_label, ax=None, *args, **kwargs):
194198 if ax is None :
195199 ax = plt .subplot ()
196200 # "!" is marked for the scf part in HSE/metaGGA calculate and will be exclude in the plot
197- kpath_sp = kpath [[i for i , j in enumerate ( k_sp_label ) if j ]]
201+ kpath_sp = kpath [[bool ( i ) for i in k_sp_label ]]
198202 ax .set_xticklabels ([])
199203 if (len (kpath_sp ) > 0 ):
200204 if (len (kpath ) > 1 ):
201- ax .vlines (kpath_sp ,
202- - 1000 , 1000 ,
203- color = 'black' ,
204- * args ,
205- ** kwargs )
205+ ax .vlines (kpath_sp , - 1000 , 1000 , color = 'black' , * args , ** kwargs )
206206 ax .set_xticks (kpath_sp )
207207 ax .set_xticklabels ([i for i in k_sp_label if i ])
208208
@@ -212,8 +212,8 @@ def plot_band_type_0(kpath, eig, ax=None, *args, **kwargs):
212212 ax = plt .subplot ()
213213 if kpath .shape [0 ] == 1 :
214214 if 'c' in kwargs .keys ():
215- kwargs ['color' ]= kwargs .pop ('c' )
216- ax .hlines (eig ,- 0.5 ,0.5 , * args , ** kwargs )
215+ kwargs ['color' ] = kwargs .pop ('c' )
216+ ax .hlines (eig , - 0.5 , 0.5 , * args , ** kwargs )
217217 else :
218218 ax .plot (kpath , eig , * args , ** kwargs )
219219
@@ -235,10 +235,10 @@ def plot_band_type_1(kpath,
235235 # for i, igroup in enumerate(project_group):
236236 # label = r'' + group_info[i][0]
237237 # ax.hlines(eig,-0.5,0.5,color=group_info[i][1],lw=project_group,label=label)
238- kpath = np .array ([- 0.5 ,0.5 ])
239- eig = eig .repeat (2 ,0 )
240- project_group = project_group .repeat (2 ,1 )
241- print (eig .shape ,project_group .shape )
238+ kpath = np .array ([- 0.5 , 0.5 ])
239+ eig = eig .repeat (2 , 0 )
240+ project_group = project_group .repeat (2 , 1 )
241+ print (eig .shape , project_group .shape )
242242 else :
243243 plot_band_type_0 (kpath , eig , ax = ax , c = 'gray' , * args , ** kwargs )
244244 x = np .linspace (kpath .min (), kpath .max (), int )
@@ -250,11 +250,11 @@ def plot_band_type_1(kpath,
250250 X = x .repeat (Nband )
251251 s = griddata (kpath , igroup , x )
252252 ax .scatter (X ,
253- Y ,
254- c = group_info [i ][1 ],
255- s = s * scale ,
256- linewidths = None ,
257- label = label , zorder = - 10 )
253+ Y ,
254+ c = group_info [i ][1 ],
255+ s = s * scale ,
256+ linewidths = None ,
257+ label = label , zorder = - 10 )
258258
259259
260260def plot_band_type_2 (kpath ,
@@ -280,18 +280,6 @@ def plot_band_type_2(kpath,
280280 lc = LineCollection (segments , cmap = "rainbow" , norm = norm )
281281 lc .set_array (c )
282282 line = ax .add_collection (lc )
283- # plt.plot(kpath[i:i+2],eig[i:i+2,j],c=col[C[i,j]])
284- # sc = ax.scatter(X,
285- # Y,
286- # c=C,
287- # vmin=0,
288- # vmax=1,
289- # s=size,
290- # linewidths=None,
291- # marker='o',
292- # cmap="hsv",
293- # *args,
294- # **kwargs)
295283 return line
296284
297285
0 commit comments