|
| 1 | +From 1677e9c5c23be7f4ff46904e7128bf129aacbb93 Mon Sep 17 00:00:00 2001 |
| 2 | +From: "vugia.truong" < [email protected]> |
| 3 | +Date: Mon, 18 Dec 2017 20:52:21 +0900 |
| 4 | +Subject: [PATCH] Add train.sh; crnn_main can now use UTF-8 labels |
| 5 | + |
| 6 | +--- |
| 7 | + crnn_main.py | 35 ++++++++++++++++++++++++----------- |
| 8 | + keys.py | 4 ++++ |
| 9 | + models/crnn.py | 3 +++ |
| 10 | + test/test_utils.py | 3 ++- |
| 11 | + train.sh | 11 +++++++++++ |
| 12 | + utils.py | 6 +++--- |
| 13 | + 6 files changed, 47 insertions(+), 15 deletions(-) |
| 14 | + create mode 100644 keys.py |
| 15 | + create mode 100644 train.sh |
| 16 | + |
| 17 | +diff --git a/crnn_main.py b/crnn_main.py |
| 18 | +index 876ffb6..dac8c20 100644 |
| 19 | +--- a/crnn_main.py |
| 20 | ++++ b/crnn_main.py |
| 21 | +@@ -1,3 +1,6 @@ |
| 22 | ++#!/usr/bin/python |
| 23 | ++# encoding: utf-8 |
| 24 | ++ |
| 25 | + from __future__ import print_function |
| 26 | + import argparse |
| 27 | + import random |
| 28 | +@@ -14,6 +17,8 @@ import dataset |
| 29 | + |
| 30 | + import models.crnn as crnn |
| 31 | + |
| 32 | ++from logger import logger |
| 33 | ++ |
| 34 | + parser = argparse.ArgumentParser() |
| 35 | + parser.add_argument('--trainroot', required=True, help='path to dataset') |
| 36 | + parser.add_argument('--valroot', required=True, help='path to dataset') |
| 37 | +@@ -28,7 +33,7 @@ parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. de |
| 38 | + parser.add_argument('--cuda', action='store_true', help='enables cuda') |
| 39 | + parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') |
| 40 | + parser.add_argument('--crnn', default='', help="path to crnn (to continue training)") |
| 41 | +-parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz') |
| 42 | ++parser.add_argument('--alphabet', type=str, default='富士通アドバンスエンジニアリング') |
| 43 | + parser.add_argument('--experiment', default=None, help='Where to store samples and models') |
| 44 | + parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed') |
| 45 | + parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test') |
| 46 | +@@ -64,16 +69,21 @@ else: |
| 47 | + sampler = None |
| 48 | + train_loader = torch.utils.data.DataLoader( |
| 49 | + train_dataset, batch_size=opt.batchSize, |
| 50 | +- shuffle=True, sampler=sampler, |
| 51 | ++ shuffle=True, |
| 52 | ++ # sampler=sampler, |
| 53 | + num_workers=int(opt.workers), |
| 54 | + collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)) |
| 55 | + test_dataset = dataset.lmdbDataset( |
| 56 | + root=opt.valroot, transform=dataset.resizeNormalize((100, 32))) |
| 57 | + |
| 58 | +-nclass = len(opt.alphabet) + 1 |
| 59 | ++alphabet_u = opt.alphabet.decode('utf-8') |
| 60 | ++print(alphabet_u) |
| 61 | ++ |
| 62 | ++nclass = len(alphabet_u) + 1 |
| 63 | ++print("Number of classes: %s" % (nclass)) |
| 64 | + nc = 1 |
| 65 | + |
| 66 | +-converter = utils.strLabelConverter(opt.alphabet) |
| 67 | ++converter = utils.strLabelConverter(alphabet_u) |
| 68 | + criterion = CTCLoss() |
| 69 | + |
| 70 | + |
| 71 | +@@ -122,7 +132,7 @@ else: |
| 72 | + |
| 73 | + |
| 74 | + def val(net, dataset, criterion, max_iter=100): |
| 75 | +- print('Start val') |
| 76 | ++ print('================ Start val') |
| 77 | + |
| 78 | + for p in crnn.parameters(): |
| 79 | + p.requires_grad = False |
| 80 | +@@ -199,14 +209,17 @@ for epoch in range(opt.niter): |
| 81 | + i += 1 |
| 82 | + |
| 83 | + if i % opt.displayInterval == 0: |
| 84 | +- print('[%d/%d][%d/%d] Loss: %f' % |
| 85 | +- (epoch, opt.niter, i, len(train_loader), loss_avg.val())) |
| 86 | ++ logger.info("%s: Epoch: [%d/%d][%d/%d] Loss: %f" % |
| 87 | ++ (i, epoch, opt.niter, i, len(train_loader), loss_avg.val())) |
| 88 | + loss_avg.reset() |
| 89 | + |
| 90 | + if i % opt.valInterval == 0: |
| 91 | ++ logger.info("%s: Validating model" % (i)) |
| 92 | + val(crnn, test_dataset, criterion) |
| 93 | + |
| 94 | +- # do checkpointing |
| 95 | +- if i % opt.saveInterval == 0: |
| 96 | +- torch.save( |
| 97 | +- crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i)) |
| 98 | ++ if epoch % 10 == 0: |
| 99 | + # save the model every 10 epochs |
| 100 | ++ logger.info("%s: Save model" % (i)) |
| 101 | ++ torch.save( |
| 102 | ++ crnn.state_dict(), '{0}/netCRNN_{1}.pth'.format(opt.experiment, epoch)) |
| 103 | ++ |
| 104 | +diff --git a/keys.py b/keys.py |
| 105 | +new file mode 100644 |
| 106 | +index 0000000..0a752cd |
| 107 | +--- /dev/null |
| 108 | ++++ b/keys.py |
| 109 | +@@ -0,0 +1,4 @@ |
| 110 | ++#!/usr/bin/python |
| 111 | ++# encoding: utf-8 |
| 112 | ++ |
| 113 | ++KEYS = "acbedgfihkjmlonqpsrutwvyxz" |
| 114 | +diff --git a/models/crnn.py b/models/crnn.py |
| 115 | +index 1dc2f60..4c16db0 100644 |
| 116 | +--- a/models/crnn.py |
| 117 | ++++ b/models/crnn.py |
| 118 | +@@ -1,3 +1,6 @@ |
| 119 | ++#!/usr/bin/python |
| 120 | ++# encoding: utf-8 |
| 121 | ++ |
| 122 | + import torch.nn as nn |
| 123 | + |
| 124 | + |
| 125 | +diff --git a/test/test_utils.py b/test/test_utils.py |
| 126 | +index 179fadf..c1d3c6b 100644 |
| 127 | +--- a/test/test_utils.py |
| 128 | ++++ b/test/test_utils.py |
| 129 | +@@ -11,6 +11,7 @@ sys.path.append("..") |
| 130 | + import utils |
| 131 | + sys.path = origin_path |
| 132 | + |
| 133 | ++from keys import KEYS |
| 134 | + |
| 135 | + def equal(a, b): |
| 136 | + if isinstance(a, torch.Tensor): |
| 137 | +@@ -29,7 +30,7 @@ def equal(a, b): |
| 138 | + class utilsTestCase(unittest.TestCase): |
| 139 | + |
| 140 | + def checkConverter(self): |
| 141 | +- encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz') |
| 142 | ++ encoder = utils.strLabelConverter(KEYS) |
| 143 | + |
| 144 | + # Encode |
| 145 | + # trivial mode |
| 146 | +diff --git a/train.sh b/train.sh |
| 147 | +new file mode 100644 |
| 148 | +index 0000000..10ec6aa |
| 149 | +--- /dev/null |
| 150 | ++++ b/train.sh |
| 151 | +@@ -0,0 +1,11 @@ |
| 152 | ++#!/usr/bin/env bash |
| 153 | ++ |
| 154 | ++python crnn_main.py \ |
| 155 | ++ --trainroot /home/fae/workspace/tanarobot-SynthText/result/samples.lmdb \ |
| 156 | ++ --valroot /home/fae/workspace/tanarobot-SynthText/result/samples.lmdb \ |
| 157 | ++ --batchSize 32 \ |
| 158 | ++ --experiment tana \ |
| 159 | ++ --cuda --adam \ |
| 160 | ++ --saveInterval 100 \ |
| 161 | ++ --displayInterval 10 \ |
| 162 | ++ --niter 100 |
| 163 | +\ No newline at end of file |
| 164 | +diff --git a/utils.py b/utils.py |
| 165 | +index 31f04b2..de6a69c 100644 |
| 166 | +--- a/utils.py |
| 167 | ++++ b/utils.py |
| 168 | +@@ -5,7 +5,7 @@ import torch |
| 169 | + import torch.nn as nn |
| 170 | + from torch.autograd import Variable |
| 171 | + import collections |
| 172 | +- |
| 173 | ++from logger import logger |
| 174 | + |
| 175 | + class strLabelConverter(object): |
| 176 | + """Convert between str and label. |
| 177 | +@@ -42,11 +42,11 @@ class strLabelConverter(object): |
| 178 | + if isinstance(text, str): |
| 179 | + text = [ |
| 180 | + self.dict[char.lower() if self._ignore_case else char] |
| 181 | +- for char in text |
| 182 | ++ for char in text.decode("utf-8") |
| 183 | + ] |
| 184 | + length = [len(text)] |
| 185 | + elif isinstance(text, collections.Iterable): |
| 186 | +- length = [len(s) for s in text] |
| 187 | ++ length = [len(s.decode("utf-8")) for s in text] |
| 188 | + text = ''.join(text) |
| 189 | + text, _ = self.encode(text) |
| 190 | + return (torch.IntTensor(text), torch.IntTensor(length)) |
| 191 | +-- |
| 192 | +2.7.4 |
| 193 | + |
0 commit comments