Skip to content

Commit

Permalink
seq_len=512 prefill w/a
Browse files Browse the repository at this point in the history
  • Loading branch information
tianmu-li committed Dec 25, 2024
1 parent 69ffcb2 commit b19dbf8
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
) -> None:
super().__init__()
self.split_gate_up = split_gate_up
self.hidden_size = hidden_size
if self.split_gate_up:
self.gate_proj = ColumnParallelLinear(
input_size=hidden_size,
Expand Down Expand Up @@ -116,13 +117,19 @@ def __init__(
self.act_fn = SiluAndMul()

def forward(self, x, skip_seq_split=False):
    """Run the MLP: (gate/up) projection, activation, down projection.

    Workaround for seq_len=512 prefill: when the flattened token count
    (batch * seq) is divisible by 512, the input is temporarily viewed as
    (-1, 512, hidden_size) rows before the projections and restored to
    (batch, seq, hidden_size) afterwards.

    Args:
        x: input tensor, assumed shape (batch, seq, hidden_size) —
            TODO(review): confirm callers never pass a 2-D token tensor.
        skip_seq_split: forwarded to ``down_proj.skip_seq_split`` before
            the down projection is applied.

    Returns:
        Tensor with the same (batch, seq, hidden_size) shape as ``x``.
    """
    batch_size = x.size(0)
    seq_len = x.size(1)
    # Compute the workaround predicate once so the reshape and its
    # inverse below are guaranteed to stay symmetric (the original code
    # duplicated this modulo test at both sites).
    reshape_to_512 = (seq_len * batch_size) % 512 == 0
    if reshape_to_512:
        x = x.view(-1, 512, self.hidden_size)
    if self.split_gate_up:
        # Separate gate/up projections: SwiGLU computed explicitly.
        x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
    else:
        # Fused gate+up projection followed by the fused activation.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
    self.down_proj.skip_seq_split = skip_seq_split
    x, _ = self.down_proj(x)
    if reshape_to_512:
        # Undo the 512-row view so callers see the original layout.
        x = x.view(batch_size, seq_len, self.hidden_size)
    return x


Expand Down

0 comments on commit b19dbf8

Please sign in to comment.