
Commit 6a2a2e8

Add support for Bias tensors (#1259)
* feat: Add support for attention and ff biases
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Add support for permuted kvq bias weights in HF conversion
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

* fix(model): Add support for bias wqkv tensor in Attention
  Branch: GraniteCodeSupport
  Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Remove prints and unnecessary dict get #1250
  Branch: BiasTensors-1250
  Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Remove unnecessary safe dict get #1250
  Branch: BiasTensors-1250
  Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 397967f commit 6a2a2e8

File tree: 2 files changed (+41 −27 lines)


torchchat/cli/convert_hf_checkpoint.py (+17 −8)

@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
+        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
+        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -93,11 +100,10 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

     def permute(w, n_heads):
-        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
             .transpose(1, 2)
-            .reshape(config.head_dim * n_heads, dim)
+            .reshape(w.shape)
         )

     merged_result = {}
@@ -130,6 +136,7 @@ def load_safetensors():
                 continue
         assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
+
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -145,16 +152,18 @@ def load_safetensors():
         final_result[new_key] = value

     for key in tuple(final_result.keys()):
-        if "wq" in key:
+        if "wq.weight" in key or "wq.bias" in key:
+            wk_key = key.replace("wq", "wk")
+            wv_key = key.replace("wq", "wv")
             q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
+            k = final_result[wk_key]
+            v = final_result[wv_key]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+            del final_result[wk_key]
+            del final_result[wv_key]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")

torchchat/model.py (+24 −19)

@@ -34,7 +34,7 @@
 try:
     # TODO: remove this after we figure out where in torchtune an `evaluate` module
     # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval # noqa
+    import lm_eval  # noqa
 except Exception:
     pass

@@ -278,6 +278,9 @@ class TransformerArgs:
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -394,7 +397,7 @@ def from_name(cls, name: str):
         config = [
             config
             for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
         ]

         # We may have two or more configs matched (e.g., "7B" and
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
-
+
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
         if recipe.fusion_class == DeepFusionModel:
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs):

         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )

-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
         self.kv_cache = None

         self.n_heads = config.n_heads
@@ -766,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
         # wv = state_dict.pop(prefix + "wv.weight")
         # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

         return

@@ -852,9 +857,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)

     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
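
The load_hook change mirrors the conversion side: a fused wqkv tensor, whether weight or bias, is split back into per-projection pieces along dim 0, sized by the query and key/value head counts. Here is a standalone sketch of just that splitting step; the head counts and dimensions are invented for illustration and do not correspond to any particular model config.

import torch

# Hypothetical grouped-query-attention sizes (more query heads than kv heads).
n_heads, n_local_heads, head_dim, dim = 8, 2, 16, 128
q_size = n_heads * head_dim         # rows belonging to wq
kv_size = n_local_heads * head_dim  # rows belonging to wk and to wv

# Fused projections as written by the conversion script: q, k, v stacked on dim 0.
wqkv_weight = torch.randn(q_size + 2 * kv_size, dim)
wqkv_bias = torch.randn(q_size + 2 * kv_size)

for fused in (wqkv_weight, wqkv_bias):
    wq, wk, wv = torch.split(fused, (q_size, kv_size, kv_size), dim=0)
    assert wq.shape[0] == q_size and wk.shape[0] == kv_size and wv.shape[0] == kv_size

With attention_bias=True and feed_forward_bias=True in TransformerArgs, the corresponding nn.Linear layers are constructed with bias parameters, so the split bias tensors have somewhere to load into.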
