@@ -1,3 +1,6 @@
+#!/usr/bin/python
+# encoding: utf-8
+
 from __future__ import print_function
 import argparse
 import random
@@ -14,6 +17,8 @@

 import models.crnn as crnn

+from logger import logger
+
 parser = argparse.ArgumentParser()
 parser.add_argument('--trainroot', required=True, help='path to dataset')
 parser.add_argument('--valroot', required=True, help='path to dataset')
@@ -28,7 +33,7 @@
 parser.add_argument('--cuda', action='store_true', help='enables cuda')
 parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
 parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
-parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
+parser.add_argument('--alphabet', type=str, default='あいうえおかきくけこ')
 parser.add_argument('--experiment', default=None, help='Where to store samples and models')
 parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
 parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
@@ -64,16 +69,21 @@
 sampler = None
 train_loader = torch.utils.data.DataLoader(
     train_dataset, batch_size=opt.batchSize,
-    shuffle=True, sampler=sampler,
+    shuffle=True,
+    # sampler=sampler,
     num_workers=int(opt.workers),
     collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
 test_dataset = dataset.lmdbDataset(
     root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

-nclass = len(opt.alphabet) + 1
+alphabet_u = opt.alphabet.decode('utf-8')
+print(alphabet_u)
+
+nclass = len(alphabet_u) + 1
+print("Number of classes: %s" % (nclass))
 nc = 1

-converter = utils.strLabelConverter(opt.alphabet)
+converter = utils.strLabelConverter(alphabet_u)
 criterion = CTCLoss()


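The decode step above matters because this script targets Python 2 (note the `#!/usr/bin/python` shebang and the `encoding: utf-8` header): the `--alphabet` value arrives as a UTF-8 byte string, and `len()` on bytes counts bytes rather than characters. A minimal illustration of the difference, assuming Python 2 semantics (the snippet is only an example, not part of the commit):

# encoding: utf-8
# Python 2: length of the raw UTF-8 bytes vs. the decoded unicode alphabet.
alphabet_bytes = 'あいうえおかきくけこ'        # str, 30 UTF-8 bytes (3 per kana)
alphabet_u = alphabet_bytes.decode('utf-8')    # unicode, 10 characters

print(len(alphabet_bytes))  # 30 -> nclass would become 31 (wrong)
print(len(alphabet_u))      # 10 -> nclass = 11 (10 symbols + 1 CTC blank)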
@@ -122,7 +132,7 @@ def weights_init(m):


 def val(net, dataset, criterion, max_iter=100):
-    print('Start val')
+    print('================ Start val')

     for p in crnn.parameters():
         p.requires_grad = False
@@ -199,14 +209,17 @@ def trainBatch(net, criterion, optimizer):
         i += 1

         if i % opt.displayInterval == 0:
-            print('[%d/%d][%d/%d] Loss: %f' %
-                  (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
+            logger.info("%s: Epoch: [%d/%d][%d/%d] Loss: %f" %
+                        (i, epoch, opt.niter, i, len(train_loader), loss_avg.val()))
             loss_avg.reset()

         if i % opt.valInterval == 0:
+            logger.info("%s: Validating model" % (i))
             val(crnn, test_dataset, criterion)

-        # do checkpointing
-        if i % opt.saveInterval == 0:
-            torch.save(
-                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
+        if epoch % 10 == 0:
+            # save model per 10 epoch
+            logger.info("%s: Save model" % (i))
+            torch.save(
+                crnn.state_dict(), '{0}/netCRNN_{1}.pth'.format(opt.experiment, epoch))
+
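The commit also relies on a local `logger` module (`from logger import logger`) that is not included in this diff. A minimal sketch of what such a module could look like, assuming only the standard-library `logging` package (logger name and message format are illustrative, not the actual file):

# logger.py -- hypothetical sketch, not the module shipped with this commit.
import logging

logger = logging.getLogger('crnn.train')   # shared logger imported by train.py
logger.setLevel(logging.INFO)

_handler = logging.StreamHandler()         # log to stderr
_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
logger.addHandler(_handler)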