
import torch
import torch.nn as nn
+from torch.autograd import Variable
import collections


class strLabelConverter(object):
11+ """Convert between str and label.
1012
11- def __init__ (self , alphabet ):
13+ NOTE:
14+ Insert `blank` to the alphabet for CTC.
15+
16+ Args:
17+ alphabet (str): set of the possible characters.
18+ ignore_case (bool, default=True): whether or not to ignore all of the case.
19+ """
20+
21+ def __init__ (self , alphabet , ignore_case = True ):
22+ self ._ignore_case = ignore_case
23+ if self ._ignore_case :
24+ alphabet = alphabet .lower ()
1225 self .alphabet = alphabet + '-' # for `-1` index
1326
1427 self .dict = {}
1528 for i , char in enumerate (alphabet ):
1629 # NOTE: 0 is reserved for 'blank' required by wrap_ctc
1730 self .dict [char ] = i + 1
1831
-    def encode(self, text, depth=0):
-        """Support batch or single str."""
+    def encode(self, text):
+        """Support batch or single str.
+
+        Args:
+            text (str or list of str): texts to convert.
+
+        Returns:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+        """
        if isinstance(text, str):
-            text = [self.dict[char.lower()] for char in text]
+            text = [
+                self.dict[char.lower() if self._ignore_case else char]
+                for char in text
+            ]
            length = [len(text)]
        elif isinstance(text, collections.Iterable):
            length = [len(s) for s in text]
            text = ''.join(text)
            text, _ = self.encode(text)
-
-        if depth:
-            return text, len(text)
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
+        """Decode encoded texts back into strs.
+
+        Args:
+            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
+            torch.IntTensor [n]: length of each text.
+
+        Raises:
+            AssertionError: when the texts and their lengths do not match.
+
+        Returns:
+            text (str or list of str): decoded texts.
+        """
        if length.numel() == 1:
            length = length[0]
-            t = t[:length]
+            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
@@ -43,26 +76,35 @@ def decode(self, t, length, raw=False):
                    char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
+            # batch mode
+            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
-                texts.append(self.decode(
-                    t[index:index + l], torch.IntTensor([l]), raw=raw))
+                texts.append(
+                    self.decode(
+                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
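
# A minimal usage sketch for strLabelConverter, not part of utils.py itself:
# the alphabet and the per-timestep prediction tensor below are hypothetical,
# and the calls assume the old PyTorch API (Variable era, collections.Iterable)
# that this file targets.
converter = strLabelConverter('abcdefghijklmnopqrstuvwxyz')
text, length = converter.encode(['hi', 'ok'])
# text   -> IntTensor([8, 9, 15, 11])   (a=1, b=2, ..., 0 is reserved for blank)
# length -> IntTensor([2, 2])
preds = torch.IntTensor([8, 8, 0, 9, 9])          # simulated CTC output "hh-ii"
decoded = converter.decode(preds, torch.IntTensor([5]))
# decoded -> 'hi'  (repeated labels and blanks are removed when raw=False)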


class averager(object):
+    """Compute average for `torch.Variable` and `torch.Tensor`."""

    def __init__(self):
        self.reset()

    def add(self, v):
-        self.n_count += v.data.numel()
-        # NOTE: not `+= v.sum()`, which will add a node in the compute graph,
-        # which lead to memory leak
-        self.sum += v.data.sum()
+        if isinstance(v, Variable):
+            count = v.data.numel()
+            v = v.data.sum()
+        elif isinstance(v, torch.Tensor):
+            count = v.numel()
+            v = v.sum()
+
+        self.n_count += count
+        self.sum += v

    def reset(self):
        self.n_count = 0
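
# A minimal usage sketch for averager, with hypothetical values: add() accepts
# either a Variable (old autograd API) or a plain Tensor, accumulating a total
# and an element count; the running mean is then sum / n_count.
loss_avg = averager()
loss_avg.add(torch.Tensor([0.5, 1.5]))   # n_count += 2, sum += 2.0
loss_avg.add(torch.Tensor([1.0]))        # n_count += 1, sum += 1.0
# loss_avg.sum / loss_avg.n_count -> 1.0
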
@@ -94,7 +136,8 @@ def loadData(v, data):

def prettyPrint(v):
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
-    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], v.mean().data[0]))
+    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
+                                              v.mean().data[0]))


def assureRatio(img):