@@ -467,7 +467,6 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
467
467
os .makedirs (folder )
468
468
fig .savefig ("%s.png" % save_path , bbox_inches = "tight" , dpi = 100 )
469
469
470
-
471
470
def visualize_new (self , cols_per_row = 3 , subnet_num = 10 ** 5 , dummy_subnet_num = 10 ** 5 , show_indices = 10 ** 5 ,
472
471
folder = "./results/" , name = "demo" , save_png = False , save_eps = False ):
473
472
@@ -484,7 +483,7 @@ def visualize_new(self, cols_per_row=3, subnet_num=10**5, dummy_subnet_num=10**5
484
483
xlim_min = - max (np .abs (projection_indices .min () - 0.1 ), np .abs (projection_indices .max () + 0.1 ))
485
484
xlim_max = max (np .abs (projection_indices .min () - 0.1 ), np .abs (projection_indices .max () + 0.1 ))
486
485
for idx , (key , item ) in enumerate (active_subnets ):
487
-
486
+
488
487
indice = item ["indice" ]
489
488
inner = outer [idx ].subgridspec (2 , 2 , wspace = 0.15 , height_ratios = [6 , 1 ], width_ratios = [3 , 1 ])
490
489
ax1_main = fig .add_subplot (inner [0 , 0 ])
@@ -498,7 +497,9 @@ def visualize_new(self, cols_per_row=3, subnet_num=10**5, dummy_subnet_num=10**5
498
497
if coef_index [np .argmax (np .abs (coef_index [:, indice ])), indice ] < 0 :
499
498
coef_index [:, indice ] = - coef_index [:, indice ]
500
499
xgrid = - xgrid
501
-
500
+ bins = - bins [::- 1 ]
501
+ density = density [::- 1 ]
502
+
502
503
ax1_main .plot (xgrid , ygrid , color = "red" )
503
504
ax1_main .set_xticklabels ([])
504
505
ax1_main .set_title ("SIM " + str (idx + 1 ) +
0 commit comments