|
1 | 1 | import torch.nn as nn |
2 | | -import utils |
3 | 2 |
|
4 | 3 |
|
class BidirectionalLSTM(nn.Module):
    """A bidirectional LSTM layer followed by a per-timestep linear projection.

    Consumes a sequence laid out as [T, b, nIn] and produces [T, b, nOut],
    where T is the sequence length and b the batch size.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        # Forward and backward hidden states are concatenated by nn.LSTM,
        # so the projection sees nHidden * 2 features per timestep.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        rnn_out, _ = self.rnn(input)  # [T, b, nHidden * 2]
        seq_len, batch, feat = rnn_out.size()
        # Flatten time and batch so a single Linear call projects every
        # timestep at once, then restore the [T, b, nOut] layout.
        flat = rnn_out.view(seq_len * batch, feat)
        projected = self.embedding(flat)  # [T * b, nOut]
        return projected.view(seq_len, batch, -1)
25 | 21 |
|
26 | 22 |
|
27 | 23 | class CRNN(nn.Module): |
28 | 24 |
|
29 | | - def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False): |
| 25 | + def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False): |
30 | 26 | super(CRNN, self).__init__() |
31 | | - self.ngpu = ngpu |
32 | 27 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' |
33 | 28 |
|
34 | 29 | ks = [3, 3, 3, 3, 3, 3, 2] |
@@ -57,31 +52,28 @@ def convRelu(i, batchNormalization=False): |
57 | 52 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 |
58 | 53 | convRelu(2, True) |
59 | 54 | convRelu(3) |
60 | | - cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2), |
61 | | - (2, 1), |
62 | | - (0, 1))) # 256x4x16 |
| 55 | + cnn.add_module('pooling{0}'.format(2), |
| 56 | + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 |
63 | 57 | convRelu(4, True) |
64 | 58 | convRelu(5) |
65 | | - cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), |
66 | | - (2, 1), |
67 | | - (0, 1))) # 512x2x16 |
| 59 | + cnn.add_module('pooling{0}'.format(3), |
| 60 | + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 |
68 | 61 | convRelu(6, True) # 512x1x16 |
69 | 62 |
|
70 | 63 | self.cnn = cnn |
71 | 64 | self.rnn = nn.Sequential( |
72 | | - BidirectionalLSTM(512, nh, nh, ngpu), |
73 | | - BidirectionalLSTM(nh, nh, nclass, ngpu) |
74 | | - ) |
| 65 | + BidirectionalLSTM(512, nh, nh), |
| 66 | + BidirectionalLSTM(nh, nh, nclass)) |
75 | 67 |
|
76 | 68 | def forward(self, input): |
77 | 69 | # conv features |
78 | | - conv = utils.data_parallel(self.cnn, input, self.ngpu) |
| 70 | + conv = self.cnn(input) |
79 | 71 | b, c, h, w = conv.size() |
80 | 72 | assert h == 1, "the height of conv must be 1" |
81 | 73 | conv = conv.squeeze(2) |
82 | 74 | conv = conv.permute(2, 0, 1) # [w, b, c] |
83 | 75 |
|
84 | 76 | # rnn features |
85 | | - output = utils.data_parallel(self.rnn, conv, self.ngpu) |
| 77 | + output = self.rnn(conv) |
86 | 78 |
|
87 | 79 | return output |
0 commit comments