Commit 9ba2eac

remove repeat_interleave since flex_decoding supports gqa
1 parent fce44e5 commit 9ba2eac

1 file changed: +0 -2 lines changed

model.py

@@ -219,8 +219,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Opti
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
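
For reviewers who want to convince themselves the two removed lines are redundant, here is a minimal pure-Python sketch of the equivalence. It is not the flex_attention kernel. It assumes that `enable_gqa` pairs query head `h` with KV head `h // group`, which is the same pairing that `repeat_interleave` along the head dimension produces. The helper names (`repeat_interleave`, `attend`, `attention_expanded`, `attention_gqa`) and the toy data are hypothetical and exist only for illustration.

```python
import math


def repeat_interleave(heads, repeats):
    """Pure-Python stand-in for Tensor.repeat_interleave along the head dim."""
    return [h for h in heads for _ in range(repeats)]


def attend(q, k, v):
    """Single-head softmax attention; q, k, v are lists of row vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[d] for w, vj in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def attention_expanded(q_heads, k_heads, v_heads):
    """Old path: materialize one K/V copy per query head, then attend."""
    group = len(q_heads) // len(k_heads)
    k_full = repeat_interleave(k_heads, group)
    v_full = repeat_interleave(v_heads, group)
    return [attend(q, k, v) for q, k, v in zip(q_heads, k_full, v_full)]


def attention_gqa(q_heads, k_heads, v_heads):
    """Assumed GQA path: query head h reads KV head h // group, no copies."""
    group = len(q_heads) // len(k_heads)
    return [
        attend(q, k_heads[h // group], v_heads[h // group])
        for h, q in enumerate(q_heads)
    ]


# Toy shapes: 4 query heads share 2 KV heads (group size 2),
# sequence length 2, head dimension 2.
q_heads = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.5, 0.5], [1.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0]],
    [[1.0, -1.0], [-1.0, 1.0]],
]
k_heads = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 1.0], [0.0, -1.0]],
]
v_heads = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
]

expanded = attention_expanded(q_heads, k_heads, v_heads)
grouped = attention_gqa(q_heads, k_heads, v_heads)
print(expanded == grouped)
```

Under that assumption, both paths compute the same per-head outputs, and the grouped path avoids materializing `n_head // n_local_heads` copies of the cached keys and values on every forward call.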
