diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 9ff2ad2277..c261cdd6f0 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -487,7 +487,7 @@ def forward( x = key_pad_mask x = x.expand(-1, self.head_count // self.parallel_gpu, -1) x = x.unsqueeze(3) - x = x.expand(-1, -1, -1, 128) + x = x.expand(-1, -1, -1, value.size(3)) value = value.masked_fill(x, 0) self.layer_cache[1]["keys"] = key