
Commit 8f45e22

fix llama2 (#10710)

1 parent: e10040b

File tree

2 files changed: +6 -2 lines changed


python/llm/src/ipex_llm/transformers/models/llama.py

+4 -2
@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
     kv_seq_len = key_states.shape[-2]

     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
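What this hunk fixes: in the first-pass branch of llama_attention_forward_4_36_quantized (before a quantized KV cache exists for the layer), query_states carries num_heads attention heads while key_states and value_states only carry num_key_value_heads heads. For grouped-query-attention Llama 2 variants those head counts differ, so matmul-ing directly against key_states and value_states breaks; the commit repeats the KV heads with repeat_kv first. Below is a minimal, self-contained sketch of the shape issue; the tensor sizes and the inlined expand/reshape (equivalent to what repeat_kv does in Hugging Face-style Llama code) are illustrative assumptions, not code taken from this repository.

import torch

# Illustrative sizes for a grouped-query-attention Llama 2 style model:
# 32 query heads sharing 8 key/value heads (all numbers are assumptions).
bsz, q_len, kv_len, head_dim = 1, 16, 16, 128
num_heads, num_kv_heads = 32, 8
n_rep = num_heads // num_kv_heads  # corresponds to self.num_key_value_groups

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_kv_heads, kv_len, head_dim)

# Before the fix: head dims 32 vs 8 cannot broadcast, so this matmul raises.
# torch.matmul(query_states, key_states.transpose(2, 3))  # RuntimeError

# After the fix: expand the 8 KV heads to 32 first (this is what repeat_kv does),
# then the attention-score matmul has matching shapes.
repeated_key_states = key_states[:, :, None, :, :].expand(
    bsz, num_kv_heads, n_rep, kv_len, head_dim
).reshape(bsz, num_heads, kv_len, head_dim)
attn_weights = torch.matmul(query_states, repeated_key_states.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([1, 32, 16, 16])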

python/llm/src/ipex_llm/transformers/models/utils.py

+2
@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):

 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
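What this hunk fixes: fp16_fusion_check previously assumed proj always exposes a qtype attribute; a module that was not quantized by ipex_llm (for example a plain nn.Linear) has none, so the check could raise AttributeError instead of simply declining fp16 fusion. Below is a small sketch of that guard pattern; the helper name and the plain nn.Linear module are assumptions for illustration only, not the project's actual classes.

import torch

def has_fp16_qtype(proj) -> bool:
    # Stand-in for the added guard in fp16_fusion_check: modules that were not
    # quantized by ipex_llm (e.g. a plain nn.Linear) carry no qtype attribute,
    # so we bail out with False instead of raising AttributeError.
    if not hasattr(proj, "qtype"):
        return False
    # The real function goes on to check qtype, weight_type, device and training mode.
    return True

plain_proj = torch.nn.Linear(16, 16)   # no qtype attribute
print(has_fp16_qtype(plain_proj))      # False -> fp16 fusion is skipped, no crash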
