Commit 4850b73

comments enhance, param check
1 parent f9dd5ce commit 4850b73

2 files changed: +70 -16 lines


test/test_utils.py (+12 -1)
@@ -51,10 +51,16 @@ def checkConverter(self):

         # replicate mode
         result = encoder.decode(
-            torch.IntTensor([5, 5, 0, 1, 0]), torch.IntTensor([4]))
+            torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([4]))
         target = 'ea'
         self.assertTrue(equal(result, target))

+        # raise AssertionError
+        def f():
+            result = encoder.decode(
+                torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([3]))
+        self.assertRaises(AssertionError, f)
+
         # batch mode
         result = encoder.decode(
             torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2]))
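
A note on what the "replicate mode" case exercises, sketched below under the assumption (not shown in this hunk) that the suite builds its converter from the lowercase alphabet. CTC-style decoding collapses repeated labels and drops the reserved blank index 0, so [5, 5, 0, 1] maps to 'ea', and the new parameter check turns a mismatched declared length into an AssertionError instead of the previous silent truncation:

    import torch
    import utils

    # Assumed setup: indices 1..26 map to 'a'..'z'; 0 is the CTC blank.
    encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz')

    # 5 -> 'e', the second 5 collapses into it, 0 is dropped, 1 -> 'a'
    print(encoder.decode(torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([4])))  # 'ea'

    # Four labels but a declared length of 3: the new check raises AssertionError.
    try:
        encoder.decode(torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([3]))
    except AssertionError as exc:
        print(exc)  # text with length: 4 does not match declared length: 3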
@@ -75,6 +81,11 @@ def checkAverager(self):
         acc.add(Variable(torch.Tensor([[5, 6]])))
         assert acc.val() == 3.5

+        acc = utils.averager()
+        acc.add(torch.Tensor([1, 2]))
+        acc.add(torch.Tensor([[5, 6]]))
+        assert acc.val() == 3.5
+
     def checkAssureRatio(self):
         img = torch.Tensor([[1], [3]]).view(1, 1, 2, 1)
         img = Variable(img)
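
The added lines mirror the Variable-based case just above: the averager now also accepts plain tensors, and the expected value is (1 + 2 + 5 + 6) / 4 = 3.5. A minimal standalone sketch of the same usage:

    import torch
    import utils

    acc = utils.averager()
    acc.add(torch.Tensor([1, 2]))    # n_count += 2, sum += 3
    acc.add(torch.Tensor([[5, 6]]))  # n_count += 2, sum += 11
    print(acc.val())                 # 14 / 4 = 3.5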

utils.py (+58 -15)
@@ -3,37 +3,70 @@

 import torch
 import torch.nn as nn
+from torch.autograd import Variable
 import collections


 class strLabelConverter(object):
+    """Convert between str and label.

-    def __init__(self, alphabet):
+    NOTE:
+        Insert `blank` to the alphabet for CTC.
+
+    Args:
+        alphabet (str): set of the possible characters.
+        ignore_case (bool, default=True): whether or not to ignore all of the case.
+    """
+
+    def __init__(self, alphabet, ignore_case=True):
+        self._ignore_case = ignore_case
+        if self._ignore_case:
+            alphabet = alphabet.lower()
         self.alphabet = alphabet + '-'  # for `-1` index

         self.dict = {}
         for i, char in enumerate(alphabet):
             # NOTE: 0 is reserved for 'blank' required by wrap_ctc
             self.dict[char] = i + 1

-    def encode(self, text, depth=0):
-        """Support batch or single str."""
+    def encode(self, text):
+        """Support batch or single str.
+
+        Args:
+            text (str or list of str): texts to convert.
+
+        Returns:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+        """
         if isinstance(text, str):
-            text = [self.dict[char.lower()] for char in text]
+            text = [
+                self.dict[char.lower() if self._ignore_case else char]
+                for char in text
+            ]
             length = [len(text)]
         elif isinstance(text, collections.Iterable):
             length = [len(s) for s in text]
             text = ''.join(text)
             text, _ = self.encode(text)
-
-        if depth:
-            return text, len(text)
         return (torch.IntTensor(text), torch.IntTensor(length))

     def decode(self, t, length, raw=False):
+        """Decode encoded texts back into strs.
+
+        Args:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+
+        Raises:
+            AssertionError: when the texts and its length does not match.
+
+        Returns:
+            text (str or list of str): texts to convert.
+        """
         if length.numel() == 1:
             length = length[0]
-            t = t[:length]
+            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
             if raw:
                 return ''.join([self.alphabet[i - 1] for i in t])
             else:
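
For orientation, a short usage sketch of the converter after this hunk; the alphabet and strings below are illustrative, and `import utils` assumes the module is importable the same way the tests import it. encode maps characters to 1-based labels, reserving 0 for the CTC blank, and with the default ignore_case=True upper-case input folds onto the same labels; decode reverses the mapping:

    import torch
    import utils

    converter = utils.strLabelConverter('abc')    # internal dict: {'a': 1, 'b': 2, 'c': 3}
    text, length = converter.encode(['ab', 'C'])  # 'C' is lowered because ignore_case=True
    print(text)                                   # values [1, 2, 3]
    print(length)                                 # values [2, 1]
    print(converter.decode(text, length))         # ['ab', 'c']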
@@ -43,26 +76,35 @@ def decode(self, t, length, raw=False):
                         char_list.append(self.alphabet[t[i] - 1])
                 return ''.join(char_list)
         else:
+            # batch mode
+            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
             texts = []
             index = 0
             for i in range(length.numel()):
                 l = length[i]
-                texts.append(self.decode(
-                    t[index:index + l], torch.IntTensor([l]), raw=raw))
+                texts.append(
+                    self.decode(
+                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                 index += l
             return texts


 class averager(object):
+    """Compute average for `torch.Variable` and `torch.Tensor`. """

     def __init__(self):
         self.reset()

     def add(self, v):
-        self.n_count += v.data.numel()
-        # NOTE: not `+= v.sum()`, which will add a node in the compute graph,
-        # which lead to memory leak
-        self.sum += v.data.sum()
+        if isinstance(v, Variable):
+            count = v.data.numel()
+            v = v.data.sum()
+        elif isinstance(v, torch.Tensor):
+            count = v.numel()
+            v = v.sum()
+
+        self.n_count += count
+        self.sum += v

     def reset(self):
         self.n_count = 0
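
A minimal sketch of what this hunk adds, written against the pre-0.4 Variable API the file targets: the batch-mode length check fails loudly on mismatched inputs, and the reworked averager keeps the original intent (accumulate raw numbers, not graph nodes) while accepting both Variables and plain tensors:

    import torch
    from torch.autograd import Variable
    import utils

    converter = utils.strLabelConverter('abc')
    try:
        # 3 labels supplied, but lengths [2, 2] declare 4 in total
        converter.decode(torch.IntTensor([1, 2, 3]), torch.IntTensor([2, 2]))
    except AssertionError as exc:
        print(exc)  # texts with length: 3 does not match declared length: 4

    # Variables are reduced via .data (so no autograd node is retained); tensors are summed directly.
    loss_avg = utils.averager()
    loss_avg.add(Variable(torch.Tensor([1.0, 2.0])))
    loss_avg.add(torch.Tensor([3.0]))
    print(loss_avg.val())  # (1.0 + 2.0 + 3.0) / 3 = 2.0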
@@ -94,7 +136,8 @@ def loadData(v, data):

 def prettyPrint(v):
     print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
-    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], v.mean().data[0]))
+    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
+                                              v.mean().data[0]))


 def assureRatio(img):
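
A small note on prettyPrint: it reads v.data.type() and indexes .data[0] on the max/min/mean results, so it is written for the old Variable API rather than plain tensors. A hedged usage sketch:

    import torch
    from torch.autograd import Variable
    import utils

    utils.prettyPrint(Variable(torch.randn(2, 3)))
    # Size torch.Size([2, 3]), Type: torch.FloatTensor
    # | Max: ... | Min: ... | Mean: ...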
