         _ignore_causal_mask_sdpa,
         and_masks,
         causal_mask_function,
-        find_packed_sequence_indices,
+        eager_mask,
         padding_mask_function,
         prepare_padding_mask,
+        sdpa_mask,
     )
     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-else:
-    causal_mask_function = None
+if is_transformers_version(">=", "4.53.1"):
+    from transformers.masking_utils import find_packed_sequence_indices

 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
@@ -420,14 +421,11 @@ def __enter__(self):
         transformers.cache_utils.Cache = TraceableCache

         if is_transformers_version(">=", "4.53"):
-            self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
-            self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
-            self.original_find_packed_sequence_indices = find_packed_sequence_indices
-
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)

         if is_transformers_version(">=", "4.53.1"):
+            self.original_find_packed_sequence_indices = find_packed_sequence_indices
             transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched

     def __exit__(self, exc_type, exc_value, traceback):
@@ -438,8 +436,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         transformers.cache_utils.Cache = self.original_cache_class

         if is_transformers_version(">=", "4.53"):
-            ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
-            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)
+            ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask)
+            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask)

         if is_transformers_version(">=", "4.53.1"):
             transformers.masking_utils.find_packed_sequence_indices = self.original_find_packed_sequence_indices
@@ -660,9 +658,12 @@ class DecoderModelPatcher(ModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.35"):
+            self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
             AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched)

         if is_transformers_version(">=", "4.36"):
+            self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
+            self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
             AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched)
             patch_everywhere(
                 "_prepare_4d_causal_attention_mask_for_sdpa",
@@ -683,21 +684,6 @@ def __exit__(self, exc_type, exc_value, traceback):
                 module_name_prefix="transformers",
             )

-    def __init__(
-        self,
-        config: "OnnxConfig",
-        model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super().__init__(config, model, model_kwargs)
-
-        if is_transformers_version(">=", "4.35"):
-            self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
-
-        if is_transformers_version(">=", "4.36"):
-            self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
-            self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
-

 def falcon_build_alibi_tensor_patched(
     attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype