Commit edd0cbf

fix bug
1 parent b45e9d3 commit edd0cbf

2 files changed: +34 −35 lines

Diff for: stage2_StrDA.py

+32-33
@@ -109,8 +109,8 @@ def pseudo_labeling(args, model, converter, target_data, adapting_list, round):
     return list_adapt_data, list_pseudo_data, mean_conf
 
 
-def self_training(args, filtered_parameters, model, criterion, converter, \
-        source_loader, valid_loader, adapting_loader, mean_conf, round = 0):
+def self_training(args, filtered_parameters, model, criterion, converter, relative_path, \
+        source_loader, valid_loader, adapting_loader, mean_conf, round=0):
 
     num_iter = (args.total_iter // args.val_interval) // args.num_subsets * args.val_interval
 
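The new relative_path parameter replaces args.method in every checkpoint and log path below. Since those writes call torch.save and open(..., "w") directly, the target directories must already exist; a minimal setup sketch (the makedirs placement is an assumption, not part of this commit):

    import os

    # Create the output directories the new paths point at; the path
    # templates come from the diff, the placement here is hypothetical.
    os.makedirs(f"trained_model/{relative_path}", exist_ok=True)
    os.makedirs(f"log/{relative_path}", exist_ok=True)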
@@ -140,7 +140,8 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
     best_score = float("-inf")
     score_descent = 0
 
-    log = ""
+    log = "-" * 80 +"\n"
+    log += "Start Self-Training (Scene Text Recognition - STR)...\n"
 
     model.train()
     # training loop
@@ -163,15 +164,14 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
                 infer_time,
                 length_of_data,
             ) = validation(model, criterion, valid_loader, converter, args)
-            model.train()
 
             if (current_score >= best_score):
                 score_descent = 0
 
                 best_score = current_score
                 torch.save(
                     model.state_dict(),
-                    f"trained_model/{args.method}/StrDA_round{round}.pth",
+                    f"trained_model/{relative_path}/{args.model}_round{round}.pth",
                 )
             else:
                 score_descent += 1
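The best-scoring weights are written by the torch.save call above; a hedged sketch of restoring them later (the path template comes from the diff, map_location="cpu" is just a safe default assumption):

    import torch

    # Load the best checkpoint saved during self-training.
    state_dict = torch.load(
        f"trained_model/{relative_path}/{args.model}_round{round}.pth",
        map_location="cpu",
    )
    model.load_state_dict(state_dict)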
@@ -188,27 +188,28 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
 
             log += valid_log
 
-            log += "-" * 80 +"\n"
+            log += "\n" + "-" * 80 +"\n"
 
             train_loss_avg.reset()
             source_loss_avg.reset()
             adapting_loss_avg.reset()
 
         if iteration == num_iter:
-            log += f"Stop training at iteration: {iteration}!\n"
-            print(f"Stop training at iteration: {iteration}!\n")
+            log += f"Stop training at iteration {iteration}!\n"
+            print(f"Stop training at iteration {iteration}!\n")
             break
 
         # training part
-        """ loss of source domain """
+        model.train()
+        """ Loss of labeled data (source domain) """
         try:
             images_source_tensor, labels_source = next(source_loader_iter)
         except StopIteration:
             del source_loader_iter
             source_loader_iter = iter(source_loader)
             images_source_tensor, labels_source = next(source_loader_iter)
 
-        images_source = images_source_tensor.to(device)
+        images_source = images_source_tensor.to(device)
         labels_source_index, labels_source_length = converter.encode(
             labels_source, batch_max_length=args.batch_max_length
         )
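Both loaders are drained with the same idiom: call next(), and on StopIteration rebuild the iterator so a finite DataLoader can be cycled indefinitely. A self-contained sketch of the pattern with generic names (not repo code):

    def infinite_batches(loader):
        # Rebuild the iterator whenever the loader is exhausted,
        # mirroring the try/except StopIteration idiom in the diff.
        it = iter(loader)
        while True:
            try:
                yield next(it)
            except StopIteration:
                it = iter(loader)
                yield next(it)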
@@ -226,27 +227,27 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
             preds_source.view(-1, preds_source.shape[-1]), target_source.contiguous().view(-1)
         )
 
-        """ loss of semi """
+        """ Loss of pseudo-labeled data (target domain) """
         try:
-            images_unlabel_tensor, labels_adapting = next(adapting_loader_iter)
+            images_adapting_tensor, labels_adapting = next(adapting_loader_iter)
         except StopIteration:
             del adapting_loader_iter
             adapting_loader_iter = iter(adapting_loader)
-            images_unlabel_tensor, labels_adapting = next(adapting_loader_iter)
+            images_adapting_tensor, labels_adapting = next(adapting_loader_iter)
 
-        images_unlabel = images_unlabel_tensor.to(device)
+        images_adapting = images_adapting_tensor.to(device)
         labels_adapting_index, labels_adapting_length = converter.encode(
             labels_adapting, batch_max_length=args.batch_max_length
         )
 
-        batch_unlabel_size = len(labels_adapting)
+        batch_adapting_size = len(labels_adapting)
         if args.Prediction == "CTC":
-            preds_adapting = model(images_unlabel)
-            preds_adapting_size = torch.IntTensor([preds_adapting.size(1)] * batch_unlabel_size)
+            preds_adapting = model(images_adapting)
+            preds_adapting_size = torch.IntTensor([preds_adapting.size(1)] * batch_adapting_size)
             preds_adapting_log_softmax = preds_adapting.log_softmax(2).permute(1, 0, 2)
             loss_adapting = criterion(preds_adapting_log_softmax, labels_adapting_index, preds_adapting_size, labels_adapting_length)
         else:
-            preds_adapting = model(images_unlabel, labels_adapting_index[:, :-1])  # align with Attention.forward
+            preds_adapting = model(images_adapting, labels_adapting_index[:, :-1])  # align with Attention.forward
             target_adapting = labels_adapting_index[:, 1:]  # without [SOS] Symbol
             loss_adapting = criterion(
                 preds_adapting.view(-1, preds_adapting.shape[-1]), target_adapting.contiguous().view(-1)
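In the CTC branch, log_softmax(2).permute(1, 0, 2) reorders predictions from (batch, time, classes) to the (time, batch, classes) layout torch.nn.CTCLoss expects. A standalone shape check; all dimensions are illustrative, and whether the repo's criterion sets zero_infinity is not shown in the diff:

    import torch

    ctc_loss = torch.nn.CTCLoss(zero_infinity=True)    # assumption: repo settings unknown
    preds = torch.randn(2, 26, 37)                     # (batch, time, classes), like preds_adapting
    log_probs = preds.log_softmax(2).permute(1, 0, 2)  # (time, batch, classes) for CTCLoss
    targets = torch.randint(1, 37, (2, 10))            # padded labels; index 0 is the blank
    input_lengths = torch.IntTensor([26, 26])          # prediction length per batch item
    target_lengths = torch.IntTensor([10, 8])          # true label length per batch item
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)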
@@ -272,11 +273,12 @@ def self_training(args, filtered_parameters, model, criterion, converter, \
         # save model
         # torch.save(
         #     model.state_dict(),
-        #     f"trained_model/{args.method}/StrDA_round{round}.pth",
+        #     f"trained_model/{relative_path}/{args.model}_round{round}.pth",
         # )
 
     # save log
-    print(log, file= open(f"log/{args.method}/log_self_training_round{round}.txt", "w"))
+    log += f"Model is saved at trained_model/{relative_path}/{args.model}_round{round}.pth"
+    print(log, file= open(f"log/{relative_path}/log_self_training_round{round}.txt", "w"))
 
     # free cache
     torch.cuda.empty_cache()
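One nit the commit keeps: print(log, file=open(..., "w")) leaves closing the file handle to the garbage collector. An equivalent write with an explicit close (same path and output, just a hedged alternative):

    # Same log file as in the diff, written via a context manager.
    with open(f"log/{relative_path}/log_self_training_round{round}.txt", "w") as f:
        print(log, file=f)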
@@ -439,36 +441,33 @@ def main(args):
 
         # self-training
         print(dashed_line)
-        print("- Seft-training...")
-        main_log += "\n- Seft-training"
-
-        # adjust mean_conf (round_down)
-        mean_conf = int(mean_conf * 10)
+        print("- Start Self-Training (Scene Text Recognition - STR)...")
+        main_log += "\n- Start Self-Training (Scene Text Recognition - STR)..."
 
         self_training_start = time.time()
         if (round >= args.checkpoint):
-            self_training(args, filtered_parameters, model, criterion, converter, \
+            self_training(args, filtered_parameters, model, criterion, converter, relative_path, \
                     source_loader, valid_loader, adapting_loader, mean_conf, round + 1)
         self_training_end = time.time()
 
         print(f"Processing time: {self_training_end - self_training_start}s")
-        print(f"Saved log for adapting round to: 'log/{args.method}/log_self_training_round{round + 1}.txt'")
-        adapt_log += f"\nProcessing time: {self_training_end - self_training_start}s"
-        adapt_log += f"\nSaved log for adapting round to: 'log/{args.method}/log_self_training_round{round + 1}.txt'"
+        print(f"Model is saved at trained_model/{relative_path}/{args.model}_round{round}.pth")
+        print(f"Saved log for adapting round to: 'log/{relative_path}/log_self_training_round{round + 1}.txt'")
 
+        main_log += f"\nProcessing time: {self_training_end - self_training_start}s"
+        main_log += f"\nModel is saved at trained_model/{relative_path}/{args.model}_round{round}.pth"
+        main_log += f"\nSaved log for adapting round to: 'log/{relative_path}/log_self_training_round{round + 1}.txt'"
         main_log += "\n" + dashed_line + "\n"
 
-        print(dashed_line)
-        print(dashed_line)
-        print(dashed_line)
+        print(dashed_line * 3)
 
         # free cache
         torch.cuda.empty_cache()
 
     # save log
     print(main_log, file= open(f"log/{args.method}/log_StrDA.txt", "w"))
 
-    return
+    return
 
 if __name__ == "__main__":
     """ Argument """

Diff for: supervised_learning.py

+2-2
@@ -210,8 +210,7 @@ def main(args):
         print(valid_log)
 
         main_log += valid_log
-
-        main_log += dashed_line + "\n"
+        main_log += "\n" + dashed_line + "\n"
 
         train_loss_avg.reset()
 

@@ -221,6 +220,7 @@ def main(args):
     # save log
     print("Training is done!")
     main_log += "Training is done!"
+    main_log += f"Model is saved at trained_model/{args.model}_supervised.pth"
     print(main_log, file= open(f"log/{args.model}_supervised.txt", "w"))
 
     print(f"Model is saved at trained_model/{args.model}_supervised.pth")
