@@ -210,8 +210,8 @@ <h1>Source code for exnn.exnn</h1><div class="highlight"><pre>
210
210
< span class ="sd "> :type l1_subnet: float</ span >
211
211
< span class ="sd "> :param l1_subnet: optional, default=0.001, the strength of L1 penalty for scaling layer.</ span >
212
212
213
- < span class ="sd "> :type smooth_lambda : float</ span >
214
- < span class ="sd "> :param smooth_lambda : optional, default=0.000001, the strength of roughness penalty for subnetworks.</ span >
213
+ < span class ="sd "> :type l2_smooth : float</ span >
214
+ < span class ="sd "> :param l2_smooth : optional, default=0.000001, the strength of roughness penalty for subnetworks.</ span >
215
215
216
216
< span class ="sd "> :type verbose: bool</ span >
217
217
< span class ="sd "> :param verbose: optional, default=False. If True, detailed messages will be printed.</ span >
@@ -233,7 +233,7 @@ <h1>Source code for exnn.exnn</h1><div class="highlight"><pre>
233
233
234
234
< span class ="k "> def</ span > < span class ="nf "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> meta_info</ span > < span class ="p "> ,</ span > < span class ="n "> subnet_num</ span > < span class ="p "> ,</ span > < span class ="n "> subnet_arch</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="mi "> 10</ span > < span class ="p "> ,</ span > < span class ="mi "> 6</ span > < span class ="p "> ],</ span > < span class ="n "> task_type</ span > < span class ="o "> =</ span > < span class ="s2 "> "Regression"</ span > < span class ="p "> ,</ span >
235
235
< span class ="n "> activation_func</ span > < span class ="o "> =</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> tanh</ span > < span class ="p "> ,</ span > < span class ="n "> batch_size</ span > < span class ="o "> =</ span > < span class ="mi "> 1000</ span > < span class ="p "> ,</ span > < span class ="n "> training_epochs</ span > < span class ="o "> =</ span > < span class ="mi "> 10000</ span > < span class ="p "> ,</ span > < span class ="n "> lr_bp</ span > < span class ="o "> =</ span > < span class ="mf "> 0.001</ span > < span class ="p "> ,</ span > < span class ="n "> lr_cl</ span > < span class ="o "> =</ span > < span class ="mf "> 0.1</ span > < span class ="p "> ,</ span >
236
- < span class ="n "> beta_threshold</ span > < span class ="o "> =</ span > < span class ="mf "> 0.05</ span > < span class ="p "> ,</ span > < span class ="n "> tuning_epochs</ span > < span class ="o "> =</ span > < span class ="mi "> 500</ span > < span class ="p "> ,</ span > < span class ="n "> l1_proj</ span > < span class ="o "> =</ span > < span class ="mf "> 0.001</ span > < span class ="p "> ,</ span > < span class ="n "> l1_subnet</ span > < span class ="o "> =</ span > < span class ="mf "> 0.001</ span > < span class ="p "> ,</ span > < span class ="n "> smooth_lambda </ span > < span class ="o "> =</ span > < span class ="mf "> 0.000001</ span > < span class ="p "> ,</ span >
236
+ < span class ="n "> beta_threshold</ span > < span class ="o "> =</ span > < span class ="mf "> 0.05</ span > < span class ="p "> ,</ span > < span class ="n "> tuning_epochs</ span > < span class ="o "> =</ span > < span class ="mi "> 500</ span > < span class ="p "> ,</ span > < span class ="n "> l1_proj</ span > < span class ="o "> =</ span > < span class ="mf "> 0.001</ span > < span class ="p "> ,</ span > < span class ="n "> l1_subnet</ span > < span class ="o "> =</ span > < span class ="mf "> 0.001</ span > < span class ="p "> ,</ span > < span class ="n "> l2_smooth </ span > < span class ="o "> =</ span > < span class ="mf "> 0.000001</ span > < span class ="p "> ,</ span >
237
237
< span class ="n "> verbose</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> val_ratio</ span > < span class ="o "> =</ span > < span class ="mf "> 0.2</ span > < span class ="p "> ,</ span > < span class ="n "> early_stop_thres</ span > < span class ="o "> =</ span > < span class ="mi "> 1000</ span > < span class ="p "> ,</ span > < span class ="n "> random_state</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span > < span class ="p "> ):</ span >
238
238
239
239
< span class ="nb "> super</ span > < span class ="p "> (</ span > < span class ="n "> ExNN</ span > < span class ="p "> ,</ span > < span class ="bp "> self</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="n "> meta_info</ span > < span class ="o "> =</ span > < span class ="n "> meta_info</ span > < span class ="p "> ,</ span >
@@ -246,7 +246,7 @@ <h1>Source code for exnn.exnn</h1><div class="highlight"><pre>
246
246
< span class ="n "> lr_bp</ span > < span class ="o "> =</ span > < span class ="n "> lr_bp</ span > < span class ="p "> ,</ span >
247
247
< span class ="n "> l1_proj</ span > < span class ="o "> =</ span > < span class ="n "> l1_proj</ span > < span class ="p "> ,</ span >
248
248
< span class ="n "> l1_subnet</ span > < span class ="o "> =</ span > < span class ="n "> l1_subnet</ span > < span class ="p "> ,</ span >
249
- < span class ="n "> smooth_lambda </ span > < span class ="o "> =</ span > < span class ="n "> smooth_lambda </ span > < span class ="p "> ,</ span >
249
+ < span class ="n "> l2_smooth </ span > < span class ="o "> =</ span > < span class ="n "> l2_smooth </ span > < span class ="p "> ,</ span >
250
250
< span class ="n "> batch_size</ span > < span class ="o "> =</ span > < span class ="n "> batch_size</ span > < span class ="p "> ,</ span >
251
251
< span class ="n "> training_epochs</ span > < span class ="o "> =</ span > < span class ="n "> training_epochs</ span > < span class ="p "> ,</ span >
252
252
< span class ="n "> tuning_epochs</ span > < span class ="o "> =</ span > < span class ="n "> tuning_epochs</ span > < span class ="p "> ,</ span >
@@ -261,12 +261,12 @@ <h1>Source code for exnn.exnn</h1><div class="highlight"><pre>
261
261
< span class ="k "> def</ span > < span class ="nf "> train_step_init</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> labels</ span > < span class ="p "> ):</ span >
262
262
< span class ="k "> with</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> GradientTape</ span > < span class ="p "> ()</ span > < span class ="k "> as</ span > < span class ="n "> tape_cl</ span > < span class ="p "> :</ span >
263
263
< span class ="k "> with</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> GradientTape</ span > < span class ="p "> ()</ span > < span class ="k "> as</ span > < span class ="n "> tape_bp</ span > < span class ="p "> :</ span >
264
- < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n " > apply </ span > < span class ="p "> (</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> training</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
264
+ < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="fm " > __call__ </ span > < span class ="p "> (</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> training</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
265
265
< span class ="n "> pred_loss</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> loss_fn</ span > < span class ="p "> (</ span > < span class ="n "> labels</ span > < span class ="p "> ,</ span > < span class ="n "> pred</ span > < span class ="p "> )</ span >
266
266
< span class ="n "> regularization_loss</ span > < span class ="o "> =</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> math</ span > < span class ="o "> .</ span > < span class ="n "> add_n</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> proj_layer</ span > < span class ="o "> .</ span > < span class ="n "> losses</ span > < span class ="o "> +</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> output_layer</ span > < span class ="o "> .</ span > < span class ="n "> losses</ span > < span class ="p "> )</ span >
267
267
< span class ="n "> cl_loss</ span > < span class ="o "> =</ span > < span class ="n "> pred_loss</ span > < span class ="o "> +</ span > < span class ="n "> regularization_loss</ span >
268
268
< span class ="n "> bp_loss</ span > < span class ="o "> =</ span > < span class ="n "> pred_loss</ span > < span class ="o "> +</ span > < span class ="n "> regularization_loss</ span >
269
- < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> smooth_lambda </ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
269
+ < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> l2_smooth </ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
270
270
< span class ="n "> smoothness_loss</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> subnet_blocks</ span > < span class ="o "> .</ span > < span class ="n "> smooth_loss</ span >
271
271
< span class ="n "> bp_loss</ span > < span class ="o "> +=</ span > < span class ="n "> smoothness_loss</ span >
272
272
@@ -291,10 +291,10 @@ <h1>Source code for exnn.exnn</h1><div class="highlight"><pre>
291
291
< span class ="k "> def</ span > < span class ="nf "> train_step_finetune</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> labels</ span > < span class ="p "> ):</ span >
292
292
293
293
< span class ="k "> with</ span > < span class ="n "> tf</ span > < span class ="o "> .</ span > < span class ="n "> GradientTape</ span > < span class ="p "> ()</ span > < span class ="k "> as</ span > < span class ="n "> tape</ span > < span class ="p "> :</ span >
294
- < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n " > apply </ span > < span class ="p "> (</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> training</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
294
+ < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="fm " > __call__ </ span > < span class ="p "> (</ span > < span class ="n "> inputs</ span > < span class ="p "> ,</ span > < span class ="n "> training</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
295
295
< span class ="n "> pred_loss</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> loss_fn</ span > < span class ="p "> (</ span > < span class ="n "> labels</ span > < span class ="p "> ,</ span > < span class ="n "> pred</ span > < span class ="p "> )</ span >
296
296
< span class ="n "> total_loss</ span > < span class ="o "> =</ span > < span class ="n "> pred_loss</ span >
297
- < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> smooth_lambda </ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
297
+ < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> l2_smooth </ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
298
298
< span class ="n "> smoothness_loss</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> subnet_blocks</ span > < span class ="o "> .</ span > < span class ="n "> smooth_loss</ span >
299
299
< span class ="n "> total_loss</ span > < span class ="o "> +=</ span > < span class ="n "> smoothness_loss</ span >
300
300
0 commit comments