Skip to content

Commit 3a502c4

Browse files
author
SAAS R7 User1
committed
update
1 parent d7eb656 commit 3a502c4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

exnn/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def fit(self, train_x, train_y):
226226
if self.verbose:
227227
print("Subnetwork pruning.")
228228

229+
self.evaluate(tr_x, tr_y, training=True) # update the batch normalization using all the training data
229230
active_me_index, active_categ_index, _, _ = self.get_active_subnets(self.beta_threshold)
230231
scal_factor = np.zeros((self.subnet_num + self.categ_variable_num, 1))
231232
scal_factor[active_me_index] = 1
@@ -267,6 +268,7 @@ def fit(self, train_x, train_y):
267268
# record the key values in the network
268269
self.subnet_input_min = []
269270
self.subnet_input_max = []
271+
self.evaluate(tr_x, tr_y, training=True) # update the batch normalization using all the training data
270272
for i in range(self.subnet_num):
271273
min_ = np.dot(train_x[:,self.noncateg_index_list], self.proj_layer.get_weights()[0])[:, i].min()
272274
max_ = np.dot(train_x[:,self.noncateg_index_list], self.proj_layer.get_weights()[0])[:, i].max()
@@ -346,4 +348,4 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
346348
if save_png:
347349
f.savefig("%s.png" % save_path, bbox_inches='tight', dpi=100)
348350
if save_eps:
349-
f.savefig("%s.eps" % save_path, bbox_inches='tight', dpi=100)
351+
f.savefig("%s.eps" % save_path, bbox_inches='tight', dpi=100)

0 commit comments

Comments
 (0)