Commit d41f0ea

more specific version handling for find_packed_sequence_indices
1 parent 77dd30f · commit d41f0ea

File tree

1 file changed (+10 -24 lines)

optimum/exporters/onnx/model_patcher.py

Lines changed: 10 additions & 24 deletions
@@ -46,13 +46,14 @@
         _ignore_causal_mask_sdpa,
         and_masks,
         causal_mask_function,
-        find_packed_sequence_indices,
+        eager_mask,
         padding_mask_function,
         prepare_padding_mask,
+        sdpa_mask,
     )
     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-else:
-    causal_mask_function = None
+if is_transformers_version(">=", "4.53.1"):
+    from transformers.masking_utils import find_packed_sequence_indices
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
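For context, a minimal sketch of the version gating this hunk narrows; it assumes optimum's is_transformers_version helper is importable as shown, with the transformers module paths taken from the diff above:

    # Sketch: symbols present since transformers 4.53 sit behind one gate,
    # while find_packed_sequence_indices (added in 4.53.1) gets its own.
    from optimum.utils import is_transformers_version

    if is_transformers_version(">=", "4.53"):
        from transformers.masking_utils import eager_mask, sdpa_mask

    if is_transformers_version(">=", "4.53.1"):
        # importing this under the broader 4.53 gate would raise
        # ImportError on transformers 4.53.0
        from transformers.masking_utils import find_packed_sequence_indices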
@@ -420,14 +421,11 @@ def __enter__(self):
         transformers.cache_utils.Cache = TraceableCache
 
         if is_transformers_version(">=", "4.53"):
-            self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
-            self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
-            self.original_find_packed_sequence_indices = find_packed_sequence_indices
-
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
 
         if is_transformers_version(">=", "4.53.1"):
+            self.original_find_packed_sequence_indices = find_packed_sequence_indices
             transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched
 
     def __exit__(self, exc_type, exc_value, traceback):
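The __enter__ change pairs the save of find_packed_sequence_indices with the same 4.53.1 gate as the patch itself, so the attribute only exists when the symbol does. A minimal sketch of that save-then-patch step, with find_packed_sequence_indices_patched as a placeholder for optimum's export-friendly replacement:

    import transformers.masking_utils

    def find_packed_sequence_indices_patched(position_ids):
        ...  # placeholder for optimum's export-friendly implementation

    # save the live function so __exit__ can restore it later ...
    original_find_packed_sequence_indices = transformers.masking_utils.find_packed_sequence_indices
    # ... then swap in the patched version for the duration of the export
    transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched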
@@ -438,8 +436,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         transformers.cache_utils.Cache = self.original_cache_class
 
         if is_transformers_version(">=", "4.53"):
-            ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
-            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)
+            ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask)
+            ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask)
 
         if is_transformers_version(">=", "4.53.1"):
             transformers.masking_utils.find_packed_sequence_indices = self.original_find_packed_sequence_indices
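On the __exit__ side, the sdpa and eager entries are now restored from the library's own sdpa_mask and eager_mask instead of from copies saved in __enter__, which is what lets the first hunk drop the original_sdpa_mask / original_eager_mask attributes. A sketch of the round trip, assuming the registry names from transformers.masking_utils (>= 4.53):

    from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask

    def sdpa_mask_without_vmap(*args, **kwargs):
        ...  # placeholder for optimum's vmap-free variant

    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)  # on enter
    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask)  # on exit: the library function *is* the original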
@@ -660,9 +658,12 @@ class DecoderModelPatcher(ModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.35"):
+            self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
             AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched)
 
         if is_transformers_version(">=", "4.36"):
+            self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
+            self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
             AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched)
             patch_everywhere(
                 "_prepare_4d_causal_attention_mask_for_sdpa",
@@ -683,21 +684,6 @@ def __exit__(self, exc_type, exc_value, traceback):
                 module_name_prefix="transformers",
             )
 
-    def __init__(
-        self,
-        config: "OnnxConfig",
-        model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super().__init__(config, model, model_kwargs)
-
-        if is_transformers_version(">=", "4.35"):
-            self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
-
-        if is_transformers_version(">=", "4.36"):
-            self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
-            self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
-
 
 def falcon_build_alibi_tensor_patched(
     attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
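The last two hunks are one refactor: DecoderModelPatcher now captures the original attention-mask helpers in __enter__, immediately before patching them, so its __init__ override (whose only job was that capture) can be deleted. A sketch of the resulting shape, with AttentionMaskConverter from transformers.modeling_attn_mask_utils and _make_causal_mask_patched as a placeholder for optimum's implementation:

    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    def _make_causal_mask_patched(*args, **kwargs):
        ...  # placeholder for optimum's ONNX-friendly causal mask

    class DecoderPatcherSketch:
        def __enter__(self):
            # capture at enter time: whatever is live *now* is what gets
            # restored, even if something else patched the converter after
            # this object was constructed
            self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask
            AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched)
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            AttentionMaskConverter._make_causal_mask = self.original_make_causal_mask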
