import torch
import torch.nn as nn
+from torch.autograd import Variable
import collections


class strLabelConverter(object):
+    """Convert between str and label.

-    def __init__(self, alphabet):
+    NOTE:
+        Insert `blank` to the alphabet for CTC.
+
+    Args:
+        alphabet (str): set of the possible characters.
+        ignore_case (bool, default=True): whether to ignore character case.
+    """
+
+    def __init__(self, alphabet, ignore_case=True):
+        self._ignore_case = ignore_case
+        if self._ignore_case:
+            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by warp_ctc
            self.dict[char] = i + 1

-    def encode(self, text, depth=0):
-        """Support batch or single str."""
+    def encode(self, text):
+        """Support batch or single str.
+
+        Args:
+            text (str or list of str): texts to convert.
+
+        Returns:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+        """
        if isinstance(text, str):
-            text = [self.dict[char.lower()] for char in text]
+            text = [
+                self.dict[char.lower() if self._ignore_case else char]
+                for char in text
+            ]
            length = [len(text)]
        elif isinstance(text, collections.Iterable):
            length = [len(s) for s in text]
            text = ''.join(text)
            text, _ = self.encode(text)
-
-        if depth:
-            return text, len(text)
        return (torch.IntTensor(text), torch.IntTensor(length))
    def decode(self, t, length, raw=False):
+        """Decode encoded texts back into strs.
+
+        Args:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+
+        Raises:
+            AssertionError: when the texts and their lengths do not match.
+
+        Returns:
+            text (str or list of str): decoded texts.
+        """
        if length.numel() == 1:
            length = length[0]
-            t = t[:length]
+            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
@@ -43,26 +76,35 @@ def decode(self, t, length, raw=False):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
+            # batch mode
+            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
-                texts.append(self.decode(
-                    t[index:index + l], torch.IntTensor([l]), raw=raw))
+                texts.append(
+                    self.decode(
+                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
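
# A minimal usage sketch for the converter above, assuming a plain lowercase
# alphabet (the alphabet string and sample words are illustrative only).
# Note that `raw=False` applies the CTC collapse rule, so the repeated 'l'
# in "hello" is merged:
#
#     converter = strLabelConverter('abcdefghijklmnopqrstuvwxyz')
#     text, length = converter.encode(['hello', 'world'])
#     # text: 1-D torch.IntTensor with 10 labels, length: torch.IntTensor([5, 5])
#     converter.decode(text, length, raw=True)   # -> ['hello', 'world']
#     converter.decode(text, length, raw=False)  # -> ['helo', 'world']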


class averager(object):
+    """Compute average for `torch.Variable` and `torch.Tensor`."""

    def __init__(self):
        self.reset()

    def add(self, v):
-        self.n_count += v.data.numel()
-        # NOTE: not `+= v.sum()`, which will add a node in the compute graph,
-        # which leads to a memory leak
-        self.sum += v.data.sum()
+        if isinstance(v, Variable):
+            count = v.data.numel()
+            v = v.data.sum()
+        elif isinstance(v, torch.Tensor):
+            count = v.numel()
+            v = v.sum()
+
+        self.n_count += count
+        self.sum += v

    def reset(self):
        self.n_count = 0
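
# A minimal sketch of how this running averager is typically driven, e.g. to
# track the mean loss over an epoch; `batch_losses` is an assumed name and the
# average is read from the fields shown above:
#
#     loss_avg = averager()
#     for loss in batch_losses:            # each a Variable or a Tensor
#         loss_avg.add(loss)               # summed via `.data`, so no graph growth
#     mean_loss = loss_avg.sum / max(loss_avg.n_count, 1)
#     loss_avg.reset()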
@@ -94,7 +136,8 @@ def loadData(v, data):

def prettyPrint(v):
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
-    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], v.mean().data[0]))
+    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
+                                              v.mean().data[0]))
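
# Quick sketch of the debug helper above; it expects a Variable-wrapped tensor
# (it reads `.data`), and the printed values below are illustrative:
#
#     prettyPrint(Variable(torch.randn(2, 3)))
#     # Size torch.Size([2, 3]), Type: torch.FloatTensor
#     # | Max: ... | Min: ... | Mean: ...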


def assureRatio(img):