diff --git a/bert_pytorch/model/transformer.py b/bert_pytorch/model/transformer.py
index 288de26..d969e9f 100644
--- a/bert_pytorch/model/transformer.py
+++ b/bert_pytorch/model/transformer.py
@@ -27,5 +27,6 @@ def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):

     def forward(self, x, mask):
         x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
-        x = self.output_sublayer(x, self.feed_forward)
+        # x = self.output_sublayer(x, self.feed_forward)
+        x = self.output_sublayer(x, lambda _x: self.feed_forward.forward(_x))
         return self.dropout(x)
diff --git a/bert_pytorch/model/utils/sublayer.py b/bert_pytorch/model/utils/sublayer.py
index 6e36793..486da44 100644
--- a/bert_pytorch/model/utils/sublayer.py
+++ b/bert_pytorch/model/utils/sublayer.py
@@ -15,4 +15,6 @@ def __init__(self, size, dropout):

     def forward(self, x, sublayer):
         "Apply residual connection to any sublayer with the same size."
-        return x + self.dropout(sublayer(self.norm(x)))
+        # return x + self.dropout(sublayer(self.norm(x)))
+        # first residual connection, then layer norm
+        return self.norm(x + self.dropout(sublayer(x)))
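
For context, the second hunk switches the sublayer wrapper from pre-layer-norm (normalize the input, then add the residual) to post-layer-norm (add the residual, then normalize the sum), as in the original Transformer paper. Below is a minimal, self-contained sketch of what the patched SublayerConnection computes; it substitutes torch.nn.LayerNorm for the repo's own LayerNorm helper, and the usage block with its dimensions is illustrative only, not part of this patch.

```python
import torch
import torch.nn as nn


class SublayerConnection(nn.Module):
    """Residual connection followed by layer normalization (post-LN),
    mirroring the change made in sublayer.py above."""

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)      # stand-in for the repo's LayerNorm helper
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Post-LN: apply the sublayer to x, drop out its output, add the
        # residual, and only then normalize the sum.
        return self.norm(x + self.dropout(sublayer(x)))


# Hypothetical usage, wrapping a feed-forward block the same way transformer.py does.
if __name__ == "__main__":
    hidden = 16
    layer = SublayerConnection(size=hidden, dropout=0.1)
    feed_forward = nn.Sequential(
        nn.Linear(hidden, 4 * hidden),
        nn.GELU(),
        nn.Linear(4 * hidden, hidden),
    )
    x = torch.randn(2, 8, hidden)                    # (batch, seq_len, hidden)
    out = layer(x, lambda _x: feed_forward(_x))
    print(out.shape)                                 # torch.Size([2, 8, 16])
```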