Skip to content

Commit 4a834d1

Browse files
authored
Apply utf-8
1 parent ab61ff1 commit 4a834d1

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

Diff for: 0001-Add-train-file.patch

+193
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
From 1677e9c5c23be7f4ff46904e7128bf129aacbb93 Mon Sep 17 00:00:00 2001
2+
From: "vugia.truong" <[email protected]>
3+
Date: Mon, 18 Dec 2017 20:52:21 +0900
4+
Subject: [PATCH] Add train file crnn_main now can use utf-8 labels
5+
6+
---
7+
crnn_main.py | 35 ++++++++++++++++++++++++-----------
8+
keys.py | 4 ++++
9+
models/crnn.py | 3 +++
10+
test/test_utils.py | 3 ++-
11+
train.sh | 11 +++++++++++
12+
utils.py | 6 +++---
13+
6 files changed, 47 insertions(+), 15 deletions(-)
14+
create mode 100644 keys.py
15+
create mode 100644 train.sh
16+
17+
diff --git a/crnn_main.py b/crnn_main.py
18+
index 876ffb6..dac8c20 100644
19+
--- a/crnn_main.py
20+
+++ b/crnn_main.py
21+
@@ -1,3 +1,6 @@
22+
+#!/usr/bin/python
23+
+# encoding: utf-8
24+
+
25+
from __future__ import print_function
26+
import argparse
27+
import random
28+
@@ -14,6 +17,8 @@ import dataset
29+
30+
import models.crnn as crnn
31+
32+
+from logger import logger
33+
+
34+
parser = argparse.ArgumentParser()
35+
parser.add_argument('--trainroot', required=True, help='path to dataset')
36+
parser.add_argument('--valroot', required=True, help='path to dataset')
37+
@@ -28,7 +33,7 @@ parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. de
38+
parser.add_argument('--cuda', action='store_true', help='enables cuda')
39+
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
40+
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
41+
-parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz')
42+
+parser.add_argument('--alphabet', type=str, default='富士通アドバンスエンジニアリング')
43+
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
44+
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
45+
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
46+
@@ -64,16 +69,21 @@ else:
47+
sampler = None
48+
train_loader = torch.utils.data.DataLoader(
49+
train_dataset, batch_size=opt.batchSize,
50+
- shuffle=True, sampler=sampler,
51+
+ shuffle=True,
52+
+ # sampler=sampler,
53+
num_workers=int(opt.workers),
54+
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
55+
test_dataset = dataset.lmdbDataset(
56+
root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))
57+
58+
-nclass = len(opt.alphabet) + 1
59+
+alphabet_u = opt.alphabet.decode('utf-8')
60+
+print(alphabet_u)
61+
+
62+
+nclass = len(alphabet_u) + 1
63+
+print("Number of classes: %s" % (nclass))
64+
nc = 1
65+
66+
-converter = utils.strLabelConverter(opt.alphabet)
67+
+converter = utils.strLabelConverter(alphabet_u)
68+
criterion = CTCLoss()
69+
70+
71+
@@ -122,7 +132,7 @@ else:
72+
73+
74+
def val(net, dataset, criterion, max_iter=100):
75+
- print('Start val')
76+
+ print('================ Start val')
77+
78+
for p in crnn.parameters():
79+
p.requires_grad = False
80+
@@ -199,14 +209,17 @@ for epoch in range(opt.niter):
81+
i += 1
82+
83+
if i % opt.displayInterval == 0:
84+
- print('[%d/%d][%d/%d] Loss: %f' %
85+
- (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
86+
+ logger.info("%s: Epoch: [%d/%d][%d/%d] Loss: %f" %
87+
+ (i, epoch, opt.niter, i, len(train_loader), loss_avg.val()))
88+
loss_avg.reset()
89+
90+
if i % opt.valInterval == 0:
91+
+ logger.info("%s: Validating model" % (i))
92+
val(crnn, test_dataset, criterion)
93+
94+
- # do checkpointing
95+
- if i % opt.saveInterval == 0:
96+
- torch.save(
97+
- crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
98+
+ if epoch % 10 == 0:
99+
+ # save model per 10 epoch
100+
+ logger.info("%s: Save model" % (i))
101+
+ torch.save(
102+
+ crnn.state_dict(), '{0}/netCRNN_{1}.pth'.format(opt.experiment, epoch))
103+
+
104+
diff --git a/keys.py b/keys.py
105+
new file mode 100644
106+
index 0000000..0a752cd
107+
--- /dev/null
108+
+++ b/keys.py
109+
@@ -0,0 +1,4 @@
110+
+#!/usr/bin/python
111+
+# encoding: utf-8
112+
+
113+
+KEYS = "acbedgfihkjmlonqpsrutwvyxz"
114+
diff --git a/models/crnn.py b/models/crnn.py
115+
index 1dc2f60..4c16db0 100644
116+
--- a/models/crnn.py
117+
+++ b/models/crnn.py
118+
@@ -1,3 +1,6 @@
119+
+#!/usr/bin/python
120+
+# encoding: utf-8
121+
+
122+
import torch.nn as nn
123+
124+
125+
diff --git a/test/test_utils.py b/test/test_utils.py
126+
index 179fadf..c1d3c6b 100644
127+
--- a/test/test_utils.py
128+
+++ b/test/test_utils.py
129+
@@ -11,6 +11,7 @@ sys.path.append("..")
130+
import utils
131+
sys.path = origin_path
132+
133+
+from keys import KEYS
134+
135+
def equal(a, b):
136+
if isinstance(a, torch.Tensor):
137+
@@ -29,7 +30,7 @@ def equal(a, b):
138+
class utilsTestCase(unittest.TestCase):
139+
140+
def checkConverter(self):
141+
- encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz')
142+
+ encoder = utils.strLabelConverter(KEYS)
143+
144+
# Encode
145+
# trivial mode
146+
diff --git a/train.sh b/train.sh
147+
new file mode 100644
148+
index 0000000..10ec6aa
149+
--- /dev/null
150+
+++ b/train.sh
151+
@@ -0,0 +1,11 @@
152+
+#!/usr/bin/env bash
153+
+
154+
+python crnn_main.py \
155+
+ --trainroot /home/fae/workspace/tanarobot-SynthText/result/samples.lmdb \
156+
+ --valroot /home/fae/workspace/tanarobot-SynthText/result/samples.lmdb \
157+
+ --batchSize 32 \
158+
+ --experiment tana \
159+
+ --cuda --adam \
160+
+ --saveInterval 100 \
161+
+ --displayInterval 10 \
162+
+ --niter 100
163+
\ No newline at end of file
164+
diff --git a/utils.py b/utils.py
165+
index 31f04b2..de6a69c 100644
166+
--- a/utils.py
167+
+++ b/utils.py
168+
@@ -5,7 +5,7 @@ import torch
169+
import torch.nn as nn
170+
from torch.autograd import Variable
171+
import collections
172+
-
173+
+from logger import logger
174+
175+
class strLabelConverter(object):
176+
"""Convert between str and label.
177+
@@ -42,11 +42,11 @@ class strLabelConverter(object):
178+
if isinstance(text, str):
179+
text = [
180+
self.dict[char.lower() if self._ignore_case else char]
181+
- for char in text
182+
+ for char in text.decode("utf-8")
183+
]
184+
length = [len(text)]
185+
elif isinstance(text, collections.Iterable):
186+
- length = [len(s) for s in text]
187+
+ length = [len(s.decode("utf-8")) for s in text]
188+
text = ''.join(text)
189+
text, _ = self.encode(text)
190+
return (torch.IntTensor(text), torch.IntTensor(length))
191+
--
192+
2.7.4
193+

0 commit comments

Comments
 (0)