 try:
     # TODO: remove this after we figure out where in torchtune an `evaluate` module
     # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval  # noqa
+    import lm_eval  # noqa
 except Exception:
     pass
 
@@ -278,6 +278,9 @@ class TransformerArgs:
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
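Both new TransformerArgs flags default to False, so model params that never mention them keep building bias-free projections. A minimal, self-contained sketch (toy class and sizes, not the torchchat code) of that default behaviour:

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class TinyArgs:
    dim: int = 64
    attention_bias: bool = False    # assumed default, mirroring the diff
    feed_forward_bias: bool = False


args = TinyArgs()                                             # a config that predates the flags
wo = nn.Linear(args.dim, args.dim, bias=args.attention_bias)
print(wo.bias is None)                                        # True: same bias-free layer as before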
@@ -394,7 +397,7 @@ def from_name(cls, name: str):
         config = [
             config
             for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
         ]
 
         # We may have two or more configs matched (e.g., "7B" and
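The old membership test only upper-cased name, so a lower-case key in known_model_params could never match. A small illustration with a hypothetical key (the real keys live in known_model_params):

known_model_params = ["granite-3b"]        # hypothetical lower-case config key
name = "Granite-3B-base"

old_match = [c for c in known_model_params if c in str(name).upper() or c in str(name)]
new_match = [c for c in known_model_params if c.upper() in str(name).upper() or c in str(name)]
print(old_match)  # []               -- the lower-case key never appears in the upper-cased name
print(new_match)  # ['granite-3b']   -- upper-casing both sides makes the match case-insensitive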
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
-
+
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
         if recipe.fusion_class == DeepFusionModel:
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs):
 
         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
 
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
         self.kv_cache = None
 
         self.n_heads = config.n_heads
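All four attention projections share the single attention_bias flag. As a reminder of the grouped-query shapes involved (toy numbers, not from any real config), wq projects to n_heads * head_dim features while wk/wv project to n_local_heads * head_dim, which is exactly the (q_size, kv_size, kv_size) split used by load_hook below:

n_heads, n_local_heads, head_dim = 8, 2, 16
q_size = n_heads * head_dim          # 128 output features for wq
kv_size = n_local_heads * head_dim   # 32 output features each for wk and wv
print(q_size, kv_size, q_size + 2 * kv_size)  # 128 32 192 -> rows of a fused wqkv tensor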
@@ -766,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
         # wv = state_dict.pop(prefix + "wv.weight")
         # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv
 
         return
 
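A sketch of what the weight/bias loop above does with a fused checkpoint (tensor shapes assumed, not taken from a real model): torch.split with the same (q_size, kv_size, kv_size) sizes along dim=0 works for both the 2-D fused weight and the 1-D fused bias, because dimension 0 indexes output features in both.

import torch

q_size, kv_size, dim = 128, 32, 256
fused_weight = torch.randn(q_size + 2 * kv_size, dim)   # stacked [q; k; v] weight
fused_bias = torch.randn(q_size + 2 * kv_size)          # stacked bias, if the checkpoint has one

wq, wk, wv = torch.split(fused_weight, (q_size, kv_size, kv_size), dim=0)
bq, bk, bv = torch.split(fused_bias, (q_size, kv_size, kv_size), dim=0)
print(wq.shape, bq.shape)  # torch.Size([128, 256]) torch.Size([128])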
@@ -852,9 +857,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
 
     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
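When feed_forward_bias is enabled in a model's params, w1/w2/w3 gain bias parameters and the matching state-dict keys are expected at load time. A quick, self-contained check of that effect (toy sizes, assumed flag value):

import torch.nn as nn

feed_forward_bias = True   # e.g. a model family whose MLP layers ship biases
w1 = nn.Linear(64, 256, bias=feed_forward_bias)
print(sorted(name for name, _ in w1.named_parameters()))  # ['bias', 'weight']
print(sorted(w1.state_dict().keys()))                     # ['bias', 'weight'] -> keys expected in the checkpoint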