diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e631161cada89..65e70c082ee18 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -78,6 +78,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.split_gate_up = split_gate_up
+        self.hidden_size = hidden_size
         if self.split_gate_up:
             self.gate_proj = ColumnParallelLinear(
                 input_size=hidden_size,
@@ -116,6 +117,10 @@ def __init__(
         self.act_fn = SiluAndMul()
 
     def forward(self, x, skip_seq_split=False):
+        batch_size = x.size(0)
+        seq_len = x.size(1)
+        if (seq_len * batch_size) % 512 == 0:
+            x = x.view(-1, 512, self.hidden_size)
         if self.split_gate_up:
             x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
         else:
@@ -123,6 +128,8 @@ def forward(self, x, skip_seq_split=False):
             x = self.act_fn(x)
         self.down_proj.skip_seq_split=skip_seq_split
         x, _ = self.down_proj(x)
+        if (seq_len * batch_size) % 512 == 0:
+            x = x.view(batch_size, seq_len, self.hidden_size)
         return x
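For reference, the logic the patch adds to forward() is a shape-preserving round trip: when the total token count divides evenly by 512, activations are folded into fixed 512-token chunks before the projections and unfolded back to (batch, seq, hidden) afterwards. Below is a minimal standalone sketch of that round trip; chunked_mlp, HIDDEN, and CHUNK are illustrative names not present in the patch, and torch.nn.Identity stands in for the real gate/up/down projections. The value 512 is taken from the patch rather than derived here.

# Minimal sketch of the reshape round trip in forward(); assumes a 3D
# activation of shape (batch, seq, hidden) as the MLP input.
import torch

HIDDEN = 4096
CHUNK = 512  # fixed chunk length used by the patch

def chunked_mlp(x: torch.Tensor, mlp: torch.nn.Module) -> torch.Tensor:
    batch_size, seq_len, hidden = x.shape
    foldable = (seq_len * batch_size) % CHUNK == 0
    if foldable:
        # Fold (batch, seq) so the inner module always sees CHUNK-token rows.
        x = x.view(-1, CHUNK, hidden)
    x = mlp(x)  # stand-in for the gate/up/act/down projection pipeline
    if foldable:
        # Restore the caller-visible (batch, seq, hidden) layout.
        x = x.view(batch_size, seq_len, hidden)
    return x

x = torch.randn(2, 1024, HIDDEN)
y = chunked_mlp(x, torch.nn.Identity())
assert y.shape == x.shape and torch.equal(x, y)  # round trip leaves data untouched

The fold is safe in this context because the MLP acts on each token independently, so the output does not depend on how tokens are grouped along the sequence dimension; only the shapes presented to the inner linear layers change.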