Skip to content

Commit 5e51169

Browse files
authored
Fixed reduce dimension in ReduceState
1 parent 454a2f6 commit 5e51169

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

training_ptr_gen/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def __init__(self):
7272

7373
def forward(self, hidden):
7474
h, c = hidden # h, c dim = 2 x b x hidden_dim
75-
hidden_reduced_h = F.relu(self.reduce_h(h.view(-1, config.hidden_dim * 2)))
76-
hidden_reduced_c = F.relu(self.reduce_c(c.view(-1, config.hidden_dim * 2)))
75+
h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
76+
hidden_reduced_h = F.relu(self.reduce_h(h_in))
77+
c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
78+
hidden_reduced_c = F.relu(self.reduce_c(c_in))
7779

7880
return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0)) # h, c dim = 1 x b x hidden_dim
7981

0 commit comments

Comments
 (0)