@@ -109,8 +109,8 @@ def pseudo_labeling(args, model, converter, target_data, adapting_list, round):
109
109
return list_adapt_data , list_pseudo_data , mean_conf
110
110
111
111
112
- def self_training (args , filtered_parameters , model , criterion , converter , \
113
- source_loader , valid_loader , adapting_loader , mean_conf , round = 0 ):
112
+ def self_training (args , filtered_parameters , model , criterion , converter , relative_path , \
113
+ source_loader , valid_loader , adapting_loader , mean_conf , round = 0 ):
114
114
115
115
num_iter = (args .total_iter // args .val_interval ) // args .num_subsets * args .val_interval
116
116
@@ -140,7 +140,8 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
140
140
best_score = float ("-inf" )
141
141
score_descent = 0
142
142
143
- log = ""
143
+ log = "-" * 80 + "\n "
144
+ log += "Start Self-Training (Scene Text Recognition - STR)...\n "
144
145
145
146
model .train ()
146
147
# training loop
@@ -163,15 +164,14 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
163
164
infer_time ,
164
165
length_of_data ,
165
166
) = validation (model , criterion , valid_loader , converter , args )
166
- model .train ()
167
167
168
168
if (current_score >= best_score ):
169
169
score_descent = 0
170
170
171
171
best_score = current_score
172
172
torch .save (
173
173
model .state_dict (),
174
- f"trained_model/{ args .method } /StrDA_round { round } .pth" ,
174
+ f"trained_model/{ relative_path } / { args .model } _round { round } .pth" ,
175
175
)
176
176
else :
177
177
score_descent += 1
@@ -188,27 +188,28 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
188
188
189
189
log += valid_log
190
190
191
- log += "-" * 80 + "\n "
191
+ log += "\n " + " -" * 80 + "\n "
192
192
193
193
train_loss_avg .reset ()
194
194
source_loss_avg .reset ()
195
195
adapting_loss_avg .reset ()
196
196
197
197
if iteration == num_iter :
198
- log += f"Stop training at iteration: { iteration } !\n "
199
- print (f"Stop training at iteration: { iteration } !\n " )
198
+ log += f"Stop training at iteration { iteration } !\n "
199
+ print (f"Stop training at iteration { iteration } !\n " )
200
200
break
201
201
202
202
# training part
203
- """ loss of source domain """
203
+ model .train ()
204
+ """ Loss of labeled data (source domain) """
204
205
try :
205
206
images_source_tensor , labels_source = next (source_loader_iter )
206
207
except StopIteration :
207
208
del source_loader_iter
208
209
source_loader_iter = iter (source_loader )
209
210
images_source_tensor , labels_source = next (source_loader_iter )
210
211
211
- images_source = images_source_tensor .to (device )
212
+ images_source = images_source_tensor .to (device )
212
213
labels_source_index , labels_source_length = converter .encode (
213
214
labels_source , batch_max_length = args .batch_max_length
214
215
)
@@ -226,27 +227,27 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
226
227
preds_source .view (- 1 , preds_source .shape [- 1 ]), target_source .contiguous ().view (- 1 )
227
228
)
228
229
229
- """ loss of semi """
230
+ """ Loss of pseudo-labeled data (target domain) """
230
231
try :
231
- images_unlabel_tensor , labels_adapting = next (adapting_loader_iter )
232
+ images_adapting_tensor , labels_adapting = next (adapting_loader_iter )
232
233
except StopIteration :
233
234
del adapting_loader_iter
234
235
adapting_loader_iter = iter (adapting_loader )
235
- images_unlabel_tensor , labels_adapting = next (adapting_loader_iter )
236
+ images_adapting_tensor , labels_adapting = next (adapting_loader_iter )
236
237
237
- images_unlabel = images_unlabel_tensor .to (device )
238
+ images_adapting = images_adapting_tensor .to (device )
238
239
labels_adapting_index , labels_adapting_length = converter .encode (
239
240
labels_adapting , batch_max_length = args .batch_max_length
240
241
)
241
242
242
- batch_unlabel_size = len (labels_adapting )
243
+ batch_adapting_size = len (labels_adapting )
243
244
if args .Prediction == "CTC" :
244
- preds_adapting = model (images_unlabel )
245
- preds_adapting_size = torch .IntTensor ([preds_adapting .size (1 )] * batch_unlabel_size )
245
+ preds_adapting = model (images_adapting )
246
+ preds_adapting_size = torch .IntTensor ([preds_adapting .size (1 )] * batch_adapting_size )
246
247
preds_adapting_log_softmax = preds_adapting .log_softmax (2 ).permute (1 , 0 , 2 )
247
248
loss_adapting = criterion (preds_adapting_log_softmax , labels_adapting_index , preds_adapting_size , labels_adapting_length )
248
249
else :
249
- preds_adapting = model (images_unlabel , labels_adapting_index [:, :- 1 ]) # align with Attention.forward
250
+ preds_adapting = model (images_adapting , labels_adapting_index [:, :- 1 ]) # align with Attention.forward
250
251
target_adapting = labels_adapting_index [:, 1 :] # without [SOS] Symbol
251
252
loss_adapting = criterion (
252
253
preds_adapting .view (- 1 , preds_adapting .shape [- 1 ]), target_adapting .contiguous ().view (- 1 )
@@ -272,11 +273,12 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
272
273
# save model
273
274
# torch.save(
274
275
# model.state_dict(),
275
- # f"trained_model/{args.method}/StrDA_round {round}.pth",
276
+ # f"trained_model/{relative_path}/{ args.model}_round {round}.pth",
276
277
# )
277
278
278
279
# save log
279
- print (log , file = open (f"log/{ args .method } /log_self_training_round{ round } .txt" , "w" ))
280
+ log += f"Model is saved at trained_model/{ relative_path } /{ args .model } _round{ round } .pth"
281
+ print (log , file = open (f"log/{ relative_path } /log_self_training_round{ round } .txt" , "w" ))
280
282
281
283
# free cache
282
284
torch .cuda .empty_cache ()
@@ -439,36 +441,33 @@ def main(args):
439
441
440
442
# self-training
441
443
print (dashed_line )
442
- print ("- Seft-training..." )
443
- main_log += "\n - Seft-training"
444
-
445
- # adjust mean_conf (round_down)
446
- mean_conf = int (mean_conf * 10 )
444
+ print ("- Start Self-Training (Scene Text Recognition - STR)..." )
445
+ main_log += "\n - Start Self-Training (Scene Text Recognition - STR)..."
447
446
448
447
self_training_start = time .time ()
449
448
if (round >= args .checkpoint ):
450
- self_training (args , filtered_parameters , model , criterion , converter , \
449
+ self_training (args , filtered_parameters , model , criterion , converter , relative_path , \
451
450
source_loader , valid_loader , adapting_loader , mean_conf , round + 1 )
452
451
self_training_end = time .time ()
453
452
454
453
print (f"Processing time: { self_training_end - self_training_start } s" )
455
- print (f"Saved log for adapting round to: 'log/{ args .method } /log_self_training_round{ round + 1 } .txt'" )
456
- adapt_log += f"\n Processing time: { self_training_end - self_training_start } s"
457
- adapt_log += f"\n Saved log for adapting round to: 'log/{ args .method } /log_self_training_round{ round + 1 } .txt'"
454
+ print (f"Model is saved at trained_model/{ relative_path } /{ args .model } _round{ round } .pth" )
455
+ print (f"Saved log for adapting round to: 'log/{ relative_path } /log_self_training_round{ round + 1 } .txt'" )
458
456
457
+ main_log += f"\n Processing time: { self_training_end - self_training_start } s"
458
+ main_log += f"\n Model is saved at trained_model/{ relative_path } /{ args .model } _round{ round } .pth"
459
+ main_log += f"\n Saved log for adapting round to: 'log/{ relative_path } /log_self_training_round{ round + 1 } .txt'"
459
460
main_log += "\n " + dashed_line + "\n "
460
461
461
- print (dashed_line )
462
- print (dashed_line )
463
- print (dashed_line )
462
+ print (dashed_line * 3 )
464
463
465
464
# free cache
466
465
torch .cuda .empty_cache ()
467
466
468
467
# save log
469
468
print (main_log , file = open (f"log/{ args .method } /log_StrDA.txt" , "w" ))
470
469
471
- return
470
+ return
472
471
473
472
if __name__ == "__main__" :
474
473
""" Argument """
0 commit comments