File tree: python/llm/src/ipex_llm/transformers/models (2 files changed, +6 -2 lines)

@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
     kv_seq_len = key_states.shape[-2]
     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)

         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
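For context (not part of the diff): this quantized-cache branch runs for grouped-query-attention models, where key_states and value_states carry fewer heads than query_states, so the score and output matmuls need the KV heads repeated num_key_value_groups times to match shapes. Below is a minimal sketch of the repeat_kv semantics the patch relies on (mirroring the usual Hugging Face transformers helper); the demo shapes are illustrative and not taken from the patch.

# Sketch of repeat_kv: expand GQA key/value heads to line up with the query heads.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Illustrative shapes only: 32 query heads sharing 8 KV heads, i.e. 4 groups.
query_states = torch.randn(1, 32, 16, 128)
key_states = torch.randn(1, 8, 16, 128)
repeated_key_states = repeat_kv(key_states, 32 // 8)
attn_weights = torch.matmul(query_states, repeated_key_states.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([1, 32, 16, 16])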
Second file in the change set:

@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
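For context (not part of the diff): fp16_fusion_check can be handed a projection module that was never converted to an ipex_llm low-bit layer, and such a module has no qtype attribute, so the old first check would raise AttributeError instead of simply skipping the fp16 fusion path. The hasattr guard returns False early in that case. A minimal sketch of the behaviour, using a hypothetical stand-in fp16_fusion_check_sketch and a plain torch.nn.Linear:

import torch

proj = torch.nn.Linear(16, 16)
print(hasattr(proj, "qtype"))  # False: accessing proj.qtype directly would raise AttributeError


def fp16_fusion_check_sketch(proj, x, training):
    # Guard added by the patch: bail out for modules without a qtype.
    if not hasattr(proj, "qtype"):
        return False
    # ... remaining checks from the patch (qtype == fp16, weight_type == 2, ...)
    return True


print(fp16_fusion_check_sketch(proj, None, False))  # False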