
Commit c5e0165

patch find_packed_sequence_indices as it's untraceable
1 parent f7b6ebd commit c5e0165

File tree

1 file changed: optimum/exporters/onnx/model_patcher.py (+35 additions, -24 deletions)
@@ -46,6 +46,7 @@
     _ignore_causal_mask_sdpa,
     and_masks,
     causal_mask_function,
+    find_packed_sequence_indices,
     padding_mask_function,
     prepare_padding_mask,
 )
@@ -206,27 +207,42 @@ def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_si
     return result


+# Custom implementation of torch.linalg.matrix_norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
 original_linal_norm = torch.linalg.norm


-# Custom implementation of torch.linalg.matrix_norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
 def onnx_compatible_linalg_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> torch.Tensor:
-    """
-    Custom implementation of torch.linalg.norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
-    It only handles the case of matrix norm with ord=2, otherwise it uses the original implementation.
-    """
+    if ord != 2:
+        raise ValueError(
+            f"Only ord=2 is supported by onnx_compatible_linalg_norm, but got ord={ord}. "
+            "Please extend this function to support other norms."
+        )

-    if ord == 2:
-        if dim is None:
-            dim = (-2, -1)
-        norm = torch.sqrt(torch.sum(torch.square(x), dim=dim, keepdim=keepdim))
-        if dtype is not None:
-            norm = norm.to(dtype)
-        if out is not None:
-            out.copy_(norm)
-        return norm
+    if dim is None:
+        dim = (-2, -1)
+
+    norm = torch.sqrt(torch.sum(torch.square(x), dim=dim, keepdim=keepdim))
+    if dtype is not None:
+        norm = norm.to(dtype)
+    if out is not None:
+        out.copy_(norm)
+
+    return norm
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, torch.linalg.norm),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]

-    return original_linal_norm(x, ord=ord, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
+
+# A patched version of https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/masking_utils.py#L602
+# That returns a tensor of zeros with the same shape as position_ids indicating no packed sequence indices.
+def find_packed_sequence_indices_patched(position_ids: torch.Tensor) -> torch.Tensor:
+    return torch.zeros_like(position_ids)


 # Custom vectorized implementation of sdpa_mask without using vmap
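
For reference, a quick standalone sketch (not part of the commit) of what the two additions above compute: onnx_compatible_linalg_norm reduces ord=2 to a sqrt-of-sum-of-squares over the selected dims, i.e. a Frobenius-style norm built only from ops the ONNX tracer handles, and find_packed_sequence_indices_patched assigns every position to sequence index 0, meaning "no packed sequences".

import torch

x = torch.randn(4, 8, 8)
# What the patched norm computes for ord=2 with the default dims (-2, -1):
patched = torch.sqrt(torch.sum(torch.square(x), dim=(-2, -1)))
# It coincides with the Frobenius norm over those dims:
print(torch.allclose(patched, torch.linalg.norm(x, ord="fro", dim=(-2, -1))))  # True

# The find_packed_sequence_indices stub just reports zeros for every position:
position_ids = torch.arange(6).unsqueeze(0)   # shape (1, 6)
print(torch.zeros_like(position_ids))         # tensor([[0, 0, 0, 0, 0, 0]])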
@@ -276,15 +292,6 @@ def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
     return mask


-UNSUPPORTED_OPS_PATCHING_SPEC = [
-    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
-    PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, original_linal_norm),
-    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
-    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
-    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
-]
-
-
 class ModelPatcher:
     def __init__(
         self,
@@ -418,8 +425,11 @@ def __enter__(self):
         if is_transformers_version(">=", "4.53"):
             self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
             self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            self.original_find_packed_sequence_indices = find_packed_sequence_indices
+
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
+            transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched

     def __exit__(self, exc_type, exc_value, traceback):
         self.restore_ops()
@@ -431,6 +441,7 @@ def __exit__(self, exc_type, exc_value, traceback):
         if is_transformers_version(">=", "4.53"):
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)
+            transformers.masking_utils.find_packed_sequence_indices = self.original_find_packed_sequence_indices

     def __call__(self, *args, **kwargs):
         if getattr(self._model, self.orig_forward_name) is self.orig_forward:
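
Taken together, the __enter__/__exit__ changes scope the monkey patch to the export context: the original transformers function is saved on enter, replaced by the traceable stub, and restored on exit. A minimal standalone sketch of that save/patch/restore pattern (assumes transformers >= 4.53, where masking_utils.find_packed_sequence_indices exists; the export call is only a placeholder):

import torch
import transformers.masking_utils

def find_packed_sequence_indices_patched(position_ids: torch.Tensor) -> torch.Tensor:
    # Traceable stub: report no packed sequences.
    return torch.zeros_like(position_ids)

original = transformers.masking_utils.find_packed_sequence_indices      # saved in __enter__
transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched
try:
    pass  # torch.onnx.export(...) would run here, tracing the stub instead of the original
finally:
    transformers.masking_utils.find_packed_sequence_indices = original  # restored in __exit__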
