@@ -98,8 +98,13 @@ def call(self, inputs, training=False):
 
         return output
 
+    @tf.function
     def predict(self, x):
-        return self.apply(tf.cast(x, tf.float32), training=False).numpy()
+        return self.apply(tf.cast(x, tf.float32), training=False)
+
+    @tf.function
+    def evaluate(self, x, y, training=False):
+        return self.loss_fn(y, self.apply(tf.cast(x, tf.float32), training=training))
 
     @tf.function
     def train_step_init(self, inputs, labels):
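Review note: dropping `.numpy()` from `predict` goes hand in hand with the new `@tf.function` decorator, since tensors inside a traced function are symbolic and eager conversion has to happen at the call site. A minimal, self-contained sketch of that pattern (the `TinyModel` class below is illustrative only, not part of this repository):

```python
import tensorflow as tf

class TinyModel(tf.Module):
    """Stand-in illustrating the pattern adopted in this diff: wrap inference
    in tf.function and return tensors, converting to NumPy only at the call site."""

    def __init__(self):
        self.w = tf.Variable(tf.ones([3, 1]))

    @tf.function
    def predict(self, x):
        # Returns a tf.Tensor; calling .numpy() here would fail during tracing,
        # because the traced graph only sees symbolic tensors.
        return tf.matmul(tf.cast(x, tf.float32), self.w)

model = TinyModel()
pred = model.predict([[1.0, 2.0, 3.0]])   # concrete EagerTensor when called eagerly
print(pred.numpy())                        # explicit conversion outside the graph
```

The new `evaluate` follows the same convention: it returns the loss as a scalar tensor, which the training loop below can log directly.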
@@ -113,13 +118,12 @@ def get_active_subnets(self):
         if self.bn_flag:
             beta = self.output_layer.output_weights.numpy() * self.output_layer.subnet_swicher.numpy()
         else:
-            subnet_norm = [self.subnet_blocks.subnets[i].subnet_bn.moving_variance.numpy()[0] ** 0.5 for i in range(self.subnet_num)]
+            subnet_norm = [self.subnet_blocks.subnets[i].subnet_norm.numpy()[0] for i in range(self.subnet_num)]
             beta = self.output_layer.output_weights.numpy() * np.array([subnet_norm]).reshape([-1, 1]) * self.output_layer.subnet_swicher.numpy()
 
         subnets_scale = (np.abs(beta) / np.sum(np.abs(beta))).reshape([-1])
         sorted_index = np.argsort(subnets_scale)
         active_index = sorted_index[subnets_scale[sorted_index].cumsum() > self.beta_threshold][::-1]
-        # active_index = sorted_index[subnets_scale[sorted_index]>self.beta_threshold][::-1]
         return active_index, beta, subnets_scale
 
     def fit(self, train_x, train_y):
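For context on `get_active_subnets`: the retained selection rule keeps subnetworks in descending order of their contribution and prunes the smallest ones whose cumulative share of |beta| stays at or below `beta_threshold`. A toy NumPy sketch of that rule (all values are made up):

```python
import numpy as np

beta = np.array([[0.5], [-0.1], [0.05], [1.2]])   # hypothetical output-layer weights
beta_threshold = 0.05                             # hypothetical pruning threshold

# Relative contribution of each subnetwork.
subnets_scale = (np.abs(beta) / np.sum(np.abs(beta))).reshape([-1])
sorted_index = np.argsort(subnets_scale)          # ascending by contribution
# Drop the leading (smallest) subnets whose cumulative share is still <= threshold,
# then reverse so the surviving indices run from largest to smallest contribution.
active_index = sorted_index[subnets_scale[sorted_index].cumsum() > beta_threshold][::-1]
print(active_index)                  # e.g. [3 0 1]: subnet 2 is pruned
print(subnets_scale[active_index])
```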
@@ -150,8 +154,8 @@ def fit(self, train_x, train_y):
                 batch_yy = tr_y[offset:(offset + self.batch_size)]
                 self.train_step_init(tf.cast(batch_xx, tf.float32), batch_yy)
 
-            self.err_train.append(self.loss_fn(tr_y, self.apply(tf.cast(tr_x, tf.float32), training=True)).numpy())
-            self.err_val.append(self.loss_fn(val_y, self.apply(tf.cast(val_x, tf.float32), training=True)).numpy())
+            self.err_train.append(self.evaluate(tr_x, tr_y, training=True))
+            self.err_val.append(self.evaluate(val_x, val_y, training=True))
             if self.verbose & (epoch % 1 == 0):
                 print("Training epoch: %d, train loss: %0.5f, val loss: %0.5f" %
                       (epoch + 1, self.err_train[-1], self.err_val[-1]))
@@ -188,12 +192,13 @@ def fit(self, train_x, train_y):
                 batch_yy = tr_y[offset:(offset + self.batch_size)]
                 self.train_step_finetune(tf.cast(batch_xx, tf.float32), batch_yy)
 
-            self.err_train.append(self.loss_fn(tr_y, self.apply(tf.cast(tr_x, tf.float32), training=True)).numpy())
-            self.err_val.append(self.loss_fn(val_y, self.apply(tf.cast(val_x, tf.float32), training=True)).numpy())
+            self.err_train.append(self.evaluate(tr_x, tr_y, training=True))
+            self.err_val.append(self.evaluate(val_x, val_y, training=True))
             if self.verbose & (epoch % 1 == 0):
                 print("Tuning epoch: %d, train loss: %0.5f, val loss: %0.5f" %
                       (epoch + 1, self.err_train[-1], self.err_val[-1]))
 
+        self.evaluate(train_x, train_y, training=True)
         # record the key values in the network
         self.subnet_input_min = []
         self.subnet_input_max = []
@@ -240,9 +245,10 @@ def visualize(self, folder="./results", name="demo", dummy_name=None, save_eps=F
                                    np.min(subnets_outputs), np.max(subnets_outputs), 6), 2)
             ax1.set_yticks(yint)
             ax1.set_yticklabels(["{0: .2f}".format(j) for j in yint])
-            legend_style = mlines.Line2D([], [], color='black', marker='o', linewidth=0.0, markersize=6,
-                                         label='Scale: ' + str(np.round(100 * subnets_scale[indice], 1)) + "%")
-            plt.legend(handles=[legend_style], fontsize=18)
+            ax1.set_ylim([np.min(subnets_outputs) - (np.max(subnets_outputs) - np.min(subnets_outputs)) * 0.1,
+                          np.max(subnets_outputs) + (np.max(subnets_outputs) - np.min(subnets_outputs)) * 0.25])
+            ax1.text(0.25, 0.9, 'Scale: ' + str(np.round(100 * subnets_scale[indice], 1)) + "%",
+                     fontsize=16, horizontalalignment='center', verticalalignment='center', transform=ax1.transAxes)
 
             ax2 = f.add_subplot(np.int(max_ids), 2, i * 2 + 2)
             ax2.bar(np.arange(input_size), coef_index.T[indice, :input_size])
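The `visualize` change swaps the marker-only legend for an `ax.text` annotation in axes coordinates and widens the y-limits so the label gets headroom above the curve. A standalone matplotlib sketch of the same idiom (the data and the scale value are made up):

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1, 1, 50)
y = np.tanh(2 * x)                      # stand-in for a subnet's output curve
scale = 37.5                            # hypothetical subnet scale, in percent

fig, ax = plt.subplots()
ax.plot(x, y)
y_range = np.max(y) - np.min(y)
# Extra headroom at the top (25%) leaves room for the annotation.
ax.set_ylim([np.min(y) - y_range * 0.1, np.max(y) + y_range * 0.25])
# (0.25, 0.9) is in axes-fraction coordinates because of transform=ax.transAxes.
ax.text(0.25, 0.9, "Scale: " + str(scale) + "%", fontsize=16,
        horizontalalignment="center", verticalalignment="center",
        transform=ax.transAxes)
plt.show()
```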