Remove old code
tianmu-li committed Dec 18, 2024
1 parent d02c025 commit 2f34cbd
Showing 1 changed file with 49 additions and 34 deletions.
vllm/model_executor/models/llama.py (83 changes: 49 additions & 34 deletions)
@@ -76,13 +76,30 @@ def __init__(
         split_size: int = 2
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[intermediate_size] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
+        self.split_gate_up = True
+        if self.split_gate_up:
+            self.gate_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_proj",
+            )
+            self.up_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.up_proj"
+            )
+        else:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[intermediate_size] * 2,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
             output_size=hidden_size,
@@ -98,8 +115,11 @@ def __init__(
         self.act_fn = SiluAndMul()

     def forward(self, x, skip_seq_split=False):
-        x, _ = self.gate_up_proj(x)
-        x = self.act_fn(x)
+        # if self.split_gate_up:
+        x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
+        # else:
+        #     x, _ = self.gate_up_proj(x)
+        #     x = self.act_fn(x)
         self.down_proj.skip_seq_split=skip_seq_split
         x, _ = self.down_proj(x)
         return x
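
For context, a minimal standalone sketch (not part of this diff) of why the split path above is numerically equivalent to the old merged gate_up_proj followed by SiluAndMul. Plain torch.nn.Linear layers stand in for vLLM's ColumnParallelLinear / MergedColumnParallelLinear, and all sizes are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
x = torch.randn(4, hidden_size)

# Merged path: one projection producing [gate | up], then silu(gate) * up,
# which is what SiluAndMul computes on the concatenated output.
gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
gate_out, up_out = gate_up(x).chunk(2, dim=-1)
merged = F.silu(gate_out) * up_out

# Split path: two independent projections holding the same weight rows.
gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)
with torch.no_grad():
    gate.weight.copy_(gate_up.weight[:intermediate_size])
    up.weight.copy_(gate_up.weight[intermediate_size:])
split = F.silu(gate(x)) * up(x)

assert torch.allclose(merged, split, atol=1e-6)
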
@@ -183,8 +203,7 @@ def __init__(
             total_num_kv_heads=self.total_num_kv_heads,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-            split_qk_v=self.split_qk_v,
+            prefix=f"{prefix}.qkv_proj"
         )

         self.o_proj = RowParallelLinear(
@@ -241,15 +260,15 @@ def forward(
         skip_seq_split: bool = False,
         **kwargs,
     ) -> torch.Tensor:
-        if self.split_qk_v:
-            # q, k, v, _ = self.qkv_proj(hidden_states)
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
-        else:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
+        # if self.split_qk_v:
+        # q, k, v, _ = self.qkv_proj(hidden_states)
+        q, _ = self.q_proj(hidden_states)
+        k, _ = self.k_proj(hidden_states)
+        v, _ = self.v_proj(hidden_states)
+        # else:
+        # qkv, _ = self.qkv_proj(hidden_states)
+        # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+        #     dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs)
         self.o_proj.skip_seq_split=skip_seq_split
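
For context, a minimal standalone sketch (not part of this diff) of what the commented-out merged path computed and why the separate q/k/v projections give the same result. Plain torch.nn.Linear stands in for QKVParallelLinear, and the head counts are illustrative:

import torch
import torch.nn as nn

hidden_size, head_dim = 64, 16
num_heads, num_kv_heads = 4, 2        # GQA: fewer kv heads than query heads
q_size = num_heads * head_dim         # 64
kv_size = num_kv_heads * head_dim     # 32

x = torch.randn(2, hidden_size)

# Merged path: a single matmul whose output is split by head sizes.
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)

# Split path: three projections that reuse the corresponding weight rows.
q_proj = nn.Linear(hidden_size, q_size, bias=False)
k_proj = nn.Linear(hidden_size, kv_size, bias=False)
v_proj = nn.Linear(hidden_size, kv_size, bias=False)
with torch.no_grad():
    q_proj.weight.copy_(qkv_proj.weight[:q_size])
    k_proj.weight.copy_(qkv_proj.weight[q_size:q_size + kv_size])
    v_proj.weight.copy_(qkv_proj.weight[q_size + kv_size:])

assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
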
@@ -488,12 +507,8 @@ def __init__(self,
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-
         self.split_qk_v = cache_config.split_qk_v
+        self.split_gate_up = True

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -519,8 +534,6 @@ def forward(
             residual = intermediate_tensors["residual"]

         if is_hpu:
-            for i in range(self.start_layer, self.end_layer):
-                self.layers[i].self_attn.rotary_emb.prepare_cos_sin(positions)
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()

@@ -542,16 +555,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
+            # (".gate_up_proj", ".gate_proj", 0),
+            # (".gate_up_proj", ".up_proj", 1),
         ]
-        if self.split_qk_v:
-            pass
-            # stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v"))
-            # stacked_params_mapping.append((".qkv_proj.k_proj", ".k_proj", "k"))
-        else:
+        if not self.split_qk_v:
             stacked_params_mapping.append((".qkv_proj", ".q_proj", "q"))
             stacked_params_mapping.append((".qkv_proj", ".k_proj", "k"))
             stacked_params_mapping.append((".qkv_proj", ".v_proj", "v"))

+        if not self.split_gate_up:
+            stacked_params_mapping.append((".gate_up_proj", ".gate_proj", 0))
+            stacked_params_mapping.append((".gate_up_proj", ".up_proj", 1))
+
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
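
For context, a minimal standalone sketch (not part of this diff) of the role stacked_params_mapping plays during weight loading: a checkpoint shard name such as ".gate_proj" is rewritten to the fused parameter name ".gate_up_proj" together with the shard it occupies. Once gate/up (and q/k/v) are kept as separate modules, checkpoint names already match the model's parameter names, so the corresponding mapping entries become unnecessary. The resolve helper and the names below are illustrative, not vLLM API:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id): merged-layer layout
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

def resolve(name):
    """Map a checkpoint weight name to (model parameter name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    # No entry matched: the name is used as-is, which is what happens for
    # the split gate_proj/up_proj and q/k/v layers after this commit.
    return name, None

print(resolve("model.layers.0.mlp.gate_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 0)
print(resolve("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
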
