From 6a46ad6a03d35241ba1055f8626f69ba8c36b973 Mon Sep 17 00:00:00 2001
From: Peyton <1115957667@qq.com>
Date: Sat, 8 Jun 2024 18:12:48 +0800
Subject: [PATCH] Add phi3 support (#481)

Co-authored-by: Casper
---
 awq/models/__init__.py     |   1 +
 awq/models/auto.py         |   1 +
 awq/models/base.py         |   1 +
 awq/models/phi3.py         | 128 +++++++++++++++++++++++++++++++++++++
 awq/modules/fused/block.py |  72 +++++++++++++++++++++
 awq/modules/fused/model.py |  67 +++++++++++++++++++
 6 files changed, 270 insertions(+)
 create mode 100644 awq/models/phi3.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index ceecf4f6..dff2fd76 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -17,4 +17,5 @@
 from .gemma import GemmaAWQForCausalLM
 from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
+from .phi3 import Phi3AWQForCausalLM
 from .cohere import CohereAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index ad307630..ad057327 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -26,6 +26,7 @@
     "gemma": GemmaAWQForCausalLM,
     "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
+    "phi3": Phi3AWQForCausalLM,
     "cohere": CohereAWQForCausalLM,
 }
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 4aff84d8..a942b507 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -70,6 +70,7 @@
     "gemma": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
+    "phi3": "AutoModelForCausalLM",
     "cohere": "AutoModelForCausalLM",
 }
 
diff --git a/awq/models/phi3.py b/awq/models/phi3.py
new file mode 100644
index 00000000..1ad1a4b6
--- /dev/null
+++ b/awq/models/phi3.py
@@ -0,0 +1,128 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import Phi3Block
+from awq.modules.fused.model import Phi3Model as AWQPhi3Model
+from transformers.models.phi3.modeling_phi3 import (
+    Phi3DecoderLayer as OldPhi3DecoderLayer,
+    Phi3ForCausalLM as OldPhi3ForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class Phi3AWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "Phi3DecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldPhi3ForCausalLM):
+        fuser = Phi3Fuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldPhi3ForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldPhi3DecoderLayer):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldPhi3ForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldPhi3DecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[module.self_attn.qkv_proj],
+                inp=input_feat["self_attn.qkv_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        layers.append(
+            dict(
+                prev_op=module.self_attn.qkv_proj,
+                layers=[module.self_attn.o_proj],
+                inp=input_feat["self_attn.o_proj"],
+            )
+        )
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_up_proj],
+                inp=input_feat["mlp.gate_up_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.gate_up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+
+        return layers
+
+
+class Phi3Fuser:
+    def __init__(self, model: OldPhi3ForCausalLM):
+        self.model = model
+
+        self.phi3_blocks: List[Tuple[str, OldPhi3DecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "Phi3DecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldPhi3DecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = module.self_attn.qkv_proj
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
+            )
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.variance_epsilon,
+            )
+            blocks.append(
+                Phi3Block(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_position_embeddings,
+                    rope_theta=self.model.config.rope_theta,
+                    rope_scaling=self.model.config.rope_scaling,
+                )
+            )
+
+        self.model.model = AWQPhi3Model(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index 0726a06a..faefef3a 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -371,3 +371,75 @@ def forward(
         out = h_attn + h_mlp
 
         return out, None, past_key_value
+
+
+class Phi3Block(nn.Module):
+    """
+    Phi3Block is intended to be reused across blocks that have
+    an architecture that closely resembles Phi-3.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        n_heads,
+        n_kv_heads,
+        qkv_layer,
+        o_proj,
+        mlp,
+        norm_1,
+        norm_2,
+        dev,
+        max_seq_len,
+        rope_theta=10000,
+        rope_scaling=None,
+        use_alibi=False,
+        head_dim=None,
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # To support models with separate head_dim
+        if head_dim:
+            self.head_dim = head_dim
+
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size,
+            self.n_heads,
+            self.n_kv_heads,
+            qkv_layer,
+            o_proj,
+            dev=dev,
+            max_seq_len=max_seq_len,
+            use_alibi=use_alibi,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            head_dim=head_dim,
+        ).to(dev)
+        self.norm_2 = norm_2.to(dev)
+        self.mlp = mlp.to(dev)
+        self.device = dev
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_value,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out = h + self.mlp.forward(self.norm_2(h))
+
+        return out, None, past_key_value
\ No newline at end of file
diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py
index 16264ed1..d1fe5437 100644
--- a/awq/modules/fused/model.py
+++ b/awq/modules/fused/model.py
@@ -11,6 +11,7 @@
     FalconDecoderLayer,
     LlamaLikeBlock,
     MixtralBlock,
+    Phi3Block,
     CohereBlock,
 )
 
@@ -306,3 +307,69 @@ def forward(
             hidden_states=(),
             attentions=(),
         )
+
+class Phi3Model(nn.Module):
+    """
+    Phi3LikeModel is intended to be reused across models that have
+    an architecture that closely resembles Phi-3.
+    """
+
+    def __init__(self, vocab_size, blocks, embedding, norm):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding = embedding
+        self.blocks: List[Phi3Block] = nn.ModuleList(blocks)
+        self.norm = norm
+        self.last_forward_num_tokens = 0
+
+    @property
+    def embed_tokens(self):
+        return self.embedding
+
+    @property
+    def layers(self):
+        return self.blocks
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+        *args,
+        **kwargs,
+    ):
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
+            input_ids, self.last_forward_num_tokens
+        )
+        _bsz, seqlen = input_ids.shape
+
+        fused_utils.prepare_cache(self.blocks, seqlen)
+
+        h = self.embedding(input_ids)
+
+        mask = fused_utils.prepare_attention_mask(
+            seqlen=seqlen,
+            start_pos=self.blocks[0].attn.start_pos,
+            device=input_ids.device,
+            type_as=h,
+        )
+
+        for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, _ = layer(
+                h, None, attention_mask=mask, is_causal=is_causal
+            )
+        h = self.norm(h)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=h,
+            past_key_values=None,
+            hidden_states=(),
+            attentions=(),
+        )
\ No newline at end of file
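
Usage note (not part of the patch above): with "phi3" registered in the model maps, a Phi-3 checkpoint can be quantized and reloaded through the standard AutoAWQ entry points. The sketch below assumes that flow; the checkpoint id (microsoft/Phi-3-mini-4k-instruct) and the quant_config values are illustrative choices, not anything this change mandates.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "microsoft/Phi-3-mini-4k-instruct"  # illustrative checkpoint, not fixed by this patch
    quant_path = "phi-3-mini-4k-instruct-awq"

    # Common AWQ settings; tune group size / bit width for your use case.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # from_pretrained dispatches on config.model_type ("phi3") to Phi3AWQForCausalLM.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Run AWQ calibration/quantization and save the quantized weights.
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

    # fuse_layers=True exercises the Phi3Fuser / Phi3Block / Phi3Model paths added here.
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)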