Skip to content

Commit b19dbf8

Browse files
committed
seq_len=512 prefill workaround
1 parent 69ffcb2 commit b19dbf8

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

vllm/model_executor/models/llama.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
) -> None:
7979
super().__init__()
8080
self.split_gate_up = split_gate_up
81+
self.hidden_size = hidden_size
8182
if self.split_gate_up:
8283
self.gate_proj = ColumnParallelLinear(
8384
input_size=hidden_size,
@@ -116,13 +117,19 @@ def __init__(
116117
self.act_fn = SiluAndMul()
117118

118119
def forward(self, x, skip_seq_split=False):
    """Run the MLP: gate/up projection, SiLU activation, down projection.

    seq_len=512 prefill workaround: when the flattened token count
    (batch_size * seq_len) is a multiple of 512, the input is folded into
    fixed (n, 512, hidden_size) chunks before the projections and restored
    to its original shape afterwards.

    Args:
        x: hidden-state tensor; assumed (batch_size, seq_len, hidden_size)
            — TODO confirm callers never pass a 2-D (num_tokens, hidden)
            tensor, since x.size(1) would then be hidden_size, not seq_len,
            and the 512-fold could trigger on the wrong dimension.
        skip_seq_split: forwarded to the down projection by assigning its
            ``skip_seq_split`` attribute before the call.

    Returns:
        Tensor with the same shape as ``x``.
    """
    batch_size = x.size(0)
    seq_len = x.size(1)
    # Hoisted: the original computed this condition twice (before and after
    # the projections); a single flag guarantees fold and unfold can never
    # disagree.
    fold_to_512 = (seq_len * batch_size) % 512 == 0
    if fold_to_512:
        # NOTE(review): view() requires a contiguous tensor; if callers may
        # pass non-contiguous inputs, reshape() would be safer — confirm.
        x = x.view(-1, 512, self.hidden_size)
    if self.split_gate_up:
        # Separate gate/up projections: apply SiLU to the gate output and
        # scale the up output elementwise.
        x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
    else:
        # Fused gate+up projection followed by the combined SiLU-and-mul op.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
    self.down_proj.skip_seq_split = skip_seq_split
    x, _ = self.down_proj(x)
    if fold_to_512:
        x = x.view(batch_size, seq_len, self.hidden_size)
    return x
127134

128135

0 commit comments

Comments
 (0)