Skip to content

Commit 772dcb8

Browse files
committed
Add utf-8 support
1 parent 4a834d1 commit 772dcb8

File tree

6 files changed

+47
-15
lines changed

6 files changed

+47
-15
lines changed

crnn_main.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#!/usr/bin/python
2+
# encoding: utf-8
3+
14
from __future__ import print_function
25
import argparse
36
import random
@@ -14,6 +17,8 @@
1417

1518
import models.crnn as crnn
1619

20+
from logger import logger
21+
1722
parser = argparse.ArgumentParser()
1823
parser.add_argument('--trainroot', required=True, help='path to dataset')
1924
parser.add_argument('--valroot', required=True, help='path to dataset')
@@ -28,7 +33,7 @@
2833
parser.add_argument('--cuda', action='store_true', help='enables cuda')
2934
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
3035
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
31-
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
36+
parser.add_argument('--alphabet', type=str, default='あいうえおかきくけこ')
3237
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
3338
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
3439
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
@@ -64,16 +69,21 @@
6469
sampler = None
6570
train_loader = torch.utils.data.DataLoader(
6671
train_dataset, batch_size=opt.batchSize,
67-
shuffle=True, sampler=sampler,
72+
shuffle=True,
73+
# sampler=sampler,
6874
num_workers=int(opt.workers),
6975
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
7076
test_dataset = dataset.lmdbDataset(
7177
root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))
7278

73-
nclass = len(opt.alphabet) + 1
79+
alphabet_u = opt.alphabet.decode('utf-8')
80+
print(alphabet_u)
81+
82+
nclass = len(alphabet_u) + 1
83+
print("Number of classes: %s" % (nclass))
7484
nc = 1
7585

76-
converter = utils.strLabelConverter(opt.alphabet)
86+
converter = utils.strLabelConverter(alphabet_u)
7787
criterion = CTCLoss()
7888

7989

@@ -122,7 +132,7 @@ def weights_init(m):
122132

123133

124134
def val(net, dataset, criterion, max_iter=100):
125-
print('Start val')
135+
print('================ Start val')
126136

127137
for p in crnn.parameters():
128138
p.requires_grad = False
@@ -199,14 +209,17 @@ def trainBatch(net, criterion, optimizer):
199209
i += 1
200210

201211
if i % opt.displayInterval == 0:
202-
print('[%d/%d][%d/%d] Loss: %f' %
203-
(epoch, opt.niter, i, len(train_loader), loss_avg.val()))
212+
logger.info("%s: Epoch: [%d/%d][%d/%d] Loss: %f" %
213+
(i, epoch, opt.niter, i, len(train_loader), loss_avg.val()))
204214
loss_avg.reset()
205215

206216
if i % opt.valInterval == 0:
217+
logger.info("%s: Validating model" % (i))
207218
val(crnn, test_dataset, criterion)
208219

209-
# do checkpointing
210-
if i % opt.saveInterval == 0:
211-
torch.save(
212-
crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
220+
if epoch % 10 == 0:
221+
# save model per 10 epoch
222+
logger.info("%s: Save model" % (i))
223+
torch.save(
224+
crnn.state_dict(), '{0}/netCRNN_{1}.pth'.format(opt.experiment, epoch))
225+

keys.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/python
2+
# encoding: utf-8
3+
4+
KEYS = "acbedgfihkjmlonqpsrutwvyxz"

models/crnn.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#!/usr/bin/python
2+
# encoding: utf-8
3+
14
import torch.nn as nn
25

36

test/test_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import utils
1212
sys.path = origin_path
1313

14+
from keys import KEYS
1415

1516
def equal(a, b):
1617
if isinstance(a, torch.Tensor):
@@ -29,7 +30,7 @@ def equal(a, b):
2930
class utilsTestCase(unittest.TestCase):
3031

3132
def checkConverter(self):
32-
encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz')
33+
encoder = utils.strLabelConverter(KEYS)
3334

3435
# Encode
3536
# trivial mode

train.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env bash
2+
3+
python crnn_main.py \
4+
--trainroot /home/gachiemchiep/workspace/tanarobot-SynthText/result/samples.lmdb \
5+
--valroot /home/gachiemchiep/workspace/tanarobot-SynthText/result/samples.lmdb \
6+
--batchSize 32 \
7+
--experiment tana \
8+
--cuda --adam \
9+
--saveInterval 100 \
10+
--displayInterval 10 \
11+
--niter 100

utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn as nn
66
from torch.autograd import Variable
77
import collections
8-
8+
from logger import logger
99

1010
class strLabelConverter(object):
1111
"""Convert between str and label.
@@ -42,11 +42,11 @@ def encode(self, text):
4242
if isinstance(text, str):
4343
text = [
4444
self.dict[char.lower() if self._ignore_case else char]
45-
for char in text
45+
for char in text.decode("utf-8")
4646
]
4747
length = [len(text)]
4848
elif isinstance(text, collections.Iterable):
49-
length = [len(s) for s in text]
49+
length = [len(s.decode("utf-8")) for s in text]
5050
text = ''.join(text)
5151
text, _ = self.encode(text)
5252
return (torch.IntTensor(text), torch.IntTensor(length))

0 commit comments

Comments
 (0)