@@ -226,6 +226,7 @@ def fit(self, train_x, train_y):
226
226
if self .verbose :
227
227
print ("Subnetwork pruning." )
228
228
229
+ self .evaluate (tr_x , tr_y , training = True ) # update the batch normalization using all the training data
229
230
active_me_index , active_categ_index , _ , _ = self .get_active_subnets (self .beta_threshold )
230
231
scal_factor = np .zeros ((self .subnet_num + self .categ_variable_num , 1 ))
231
232
scal_factor [active_me_index ] = 1
@@ -267,6 +268,7 @@ def fit(self, train_x, train_y):
267
268
# record the key values in the network
268
269
self .subnet_input_min = []
269
270
self .subnet_input_max = []
271
+ self .evaluate (tr_x , tr_y , training = True ) # update the batch normalization using all the training data
270
272
for i in range (self .subnet_num ):
271
273
min_ = np .dot (train_x [:,self .noncateg_index_list ], self .proj_layer .get_weights ()[0 ])[:, i ].min ()
272
274
max_ = np .dot (train_x [:,self .noncateg_index_list ], self .proj_layer .get_weights ()[0 ])[:, i ].max ()
@@ -346,4 +348,4 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
346
348
if save_png :
347
349
f .savefig ("%s.png" % save_path , bbox_inches = 'tight' , dpi = 100 )
348
350
if save_eps :
349
- f .savefig ("%s.eps" % save_path , bbox_inches = 'tight' , dpi = 100 )
351
+ f .savefig ("%s.eps" % save_path , bbox_inches = 'tight' , dpi = 100 )
0 commit comments