@@ -226,6 +226,7 @@ def fit(self, train_x, train_y):
226226 if self .verbose :
227227 print ("Subnetwork pruning." )
228228
229+ self .evaluate (tr_x , tr_y , training = True ) # update the batch normalization using all the training data
229230 active_me_index , active_categ_index , _ , _ = self .get_active_subnets (self .beta_threshold )
230231 scal_factor = np .zeros ((self .subnet_num + self .categ_variable_num , 1 ))
231232 scal_factor [active_me_index ] = 1
@@ -267,6 +268,7 @@ def fit(self, train_x, train_y):
267268 # record the key values in the network
268269 self .subnet_input_min = []
269270 self .subnet_input_max = []
271+ self .evaluate (tr_x , tr_y , training = True ) # update the batch normalization using all the training data
270272 for i in range (self .subnet_num ):
271273 min_ = np .dot (train_x [:,self .noncateg_index_list ], self .proj_layer .get_weights ()[0 ])[:, i ].min ()
272274 max_ = np .dot (train_x [:,self .noncateg_index_list ], self .proj_layer .get_weights ()[0 ])[:, i ].max ()
@@ -346,4 +348,4 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
346348 if save_png :
347349 f .savefig ("%s.png" % save_path , bbox_inches = 'tight' , dpi = 100 )
348350 if save_eps :
349- f .savefig ("%s.eps" % save_path , bbox_inches = 'tight' , dpi = 100 )
351+ f .savefig ("%s.eps" % save_path , bbox_inches = 'tight' , dpi = 100 )
0 commit comments