
Commit e908eec

njhill authored and Xaenalt committed
Fix llama gqa attention bias (IBM#88)
To support IBM Granite Code 8B models.

Signed-off-by: Nick Hill <[email protected]>
1 parent 5953686 commit e908eec

File tree

2 files changed (+16, −8 lines)


server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -155,10 +155,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -172,7 +171,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):
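
The substance of the change: the old code passed config.attention_bias, a boolean config flag, straight through to get_linear as the bias, while the fix loads each projection's bias shard with weights.get_sharded and concatenates them along dim 0 so that get_linear receives a real tensor (or None) whose rows line up with the fused q/k/v weight. A minimal sketch of that shape bookkeeping with plain torch tensors; the head counts and sizes below are illustrative, not taken from this commit:

    import torch

    head_size = 128
    num_heads = 32        # query heads on this shard
    num_kv_heads = 8      # key/value heads on this shard (GQA: fewer than query heads)

    # Stand-ins for what weights.get_sharded(f"{p}.bias", dim=0) would return per projection.
    q_bias = torch.zeros(num_heads * head_size)
    k_bias = torch.zeros(num_kv_heads * head_size)
    v_bias = torch.zeros(num_kv_heads * head_size)

    # Concatenate along dim 0 so the bias matches the fused [q; k; v] column-parallel weight.
    bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
    assert bias.shape[0] == (num_heads + 2 * num_kv_heads) * head_size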

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 8 additions & 4 deletions
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class PagedLlamaAttention(torch.nn.Module):
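
The same fix is applied to the paged-attention variant of the model. For context on why it matters (per the commit message, IBM Granite Code 8B support): these models follow the Llama architecture but set attention_bias=True in their config, which is what routes _load_gqa into the new bias branch. A hedged illustration using Hugging Face's LlamaConfig; the hyperparameter values are made up for the example, not copied from any released checkpoint:

    from transformers import LlamaConfig

    # Illustrative config only: a Llama-architecture model with GQA and biased attention
    # projections, the combination the patched _load_gqa path is meant to handle.
    config = LlamaConfig(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,   # grouped-query attention: fewer k/v heads than query heads
        attention_bias=True,     # q/k/v projections carry bias terms
    )
    print(config.attention_bias)  # True -> the bias-concatenation branch above is taken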
