From e9edb12bfbc1457a55cbef5df664211011eb98ca Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 21 May 2024 10:13:21 +0200 Subject: [PATCH] some code cleaning --- onmt/modules/multi_headed_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 9ff2ad2277..c261cdd6f0 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -487,7 +487,7 @@ def forward( x = key_pad_mask x = x.expand(-1, self.head_count // self.parallel_gpu, -1) x = x.unsqueeze(3) - x = x.expand(-1, -1, -1, 128) + x = x.expand(-1, -1, -1, value.size(3)) value = value.masked_fill(x, 0) self.layer_cache[1]["keys"] = key