     _ignore_causal_mask_sdpa,
     and_masks,
     causal_mask_function,
+    find_packed_sequence_indices,
     padding_mask_function,
     prepare_padding_mask,
 )
@@ -206,27 +207,42 @@ def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None, output_si
     return result


+# Custom implementation of torch.linalg.matrix_norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
 original_linal_norm = torch.linalg.norm


-# Custom implementation of torch.linalg.matrix_norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
 def onnx_compatible_linalg_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> torch.Tensor:
-    """
-    Custom implementation of torch.linalg.norm not using torch.linalg.matrix_norm, torch.norm or torch.linalg.norm.
-    It only handles the case of matrix norm with ord=2, otherwise it uses the original implementation.
-    """
+    if ord != 2:
+        raise ValueError(
+            f"Only ord=2 is supported by onnx_compatible_linalg_norm, but got ord={ord}. "
+            "Please extend this function to support other norms."
+        )

-    if ord == 2:
-        if dim is None:
-            dim = (-2, -1)
-        norm = torch.sqrt(torch.sum(torch.square(x), dim=dim, keepdim=keepdim))
-        if dtype is not None:
-            norm = norm.to(dtype)
-        if out is not None:
-            out.copy_(norm)
-        return norm
+    if dim is None:
+        dim = (-2, -1)
+
+    norm = torch.sqrt(torch.sum(torch.square(x), dim=dim, keepdim=keepdim))
+    if dtype is not None:
+        norm = norm.to(dtype)
+    if out is not None:
+        out.copy_(norm)
+
+    return norm
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, torch.linalg.norm),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]

-    return original_linal_norm(x, ord=ord, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
+
+# A patched version of https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/masking_utils.py#L602
+# that returns a tensor of zeros with the same shape as position_ids, indicating no packed sequence indices.
+def find_packed_sequence_indices_patched(position_ids: torch.Tensor) -> torch.Tensor:
+    return torch.zeros_like(position_ids)


 # Custom vectorized implementation of sdpa_mask without using vmap
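Aside, not part of the patch: the decomposed norm introduced above can be sanity-checked against torch.linalg.norm. For a single reduction dim the two agree exactly, while for a pair of dims the decomposition corresponds to the Frobenius norm rather than the spectral norm that torch.linalg.norm computes for ord=2 on matrices. A minimal illustrative check:

import torch

x = torch.randn(3, 5)

# Vector case: ord=2 over one dim is exactly sqrt(sum(x**2)).
reference = torch.linalg.norm(x, ord=2, dim=-1)
decomposed = torch.sqrt(torch.sum(torch.square(x), dim=-1))
assert torch.allclose(reference, decomposed)

# Matrix case: the same decomposition over (-2, -1) is the Frobenius norm.
frobenius = torch.sqrt(torch.sum(torch.square(x), dim=(-2, -1)))
assert torch.allclose(frobenius, torch.linalg.norm(x, ord="fro", dim=(-2, -1)))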
@@ -276,15 +292,6 @@ def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
     return mask


-UNSUPPORTED_OPS_PATCHING_SPEC = [
-    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
-    PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, original_linal_norm),
-    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
-    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
-    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
-]
-
-
 class ModelPatcher:
     def __init__(
         self,
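For readers unfamiliar with the patching machinery: each PatchingSpec entry above bundles an owner object, an attribute name, the ONNX-compatible replacement, and the original op. The sketch below shows how such specs are typically applied and restored via setattr; it is an assumption inferred from how PatchingSpec is constructed in this file, not optimum's actual implementation.

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class SimplePatchingSpec:  # hypothetical stand-in for optimum's PatchingSpec
    owner: Any            # e.g. torch.Tensor or torch.linalg
    name: str             # attribute to patch, e.g. "unfold"
    custom_op: Callable   # export-friendly replacement
    orig_op: Callable     # original op, kept so it can be restored

def apply_specs(specs):
    # Swap each op for its ONNX-compatible replacement before tracing.
    for spec in specs:
        setattr(spec.owner, spec.name, spec.custom_op)

def restore_specs(specs):
    # Undo the monkey-patching once tracing/export is finished.
    for spec in specs:
        setattr(spec.owner, spec.name, spec.orig_op)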
@@ -418,8 +425,11 @@ def __enter__(self):
         if is_transformers_version(">=", "4.53"):
             self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
             self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            self.original_find_packed_sequence_indices = find_packed_sequence_indices
+
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)
+            transformers.masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched

     def __exit__(self, exc_type, exc_value, traceback):
         self.restore_ops()
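The find_packed_sequence_indices override takes effect because callers inside transformers resolve the helper through the module's globals at call time, which is what this patch relies on: the original is captured from the imported name so __exit__ can put it back, while the replacement is installed by reassigning the module attribute. A standalone sketch of the same swap, illustrative only:

import transformers.masking_utils as masking_utils

original = masking_utils.find_packed_sequence_indices
masking_utils.find_packed_sequence_indices = find_packed_sequence_indices_patched
try:
    # Trace/export the model here: code that looks the helper up through
    # transformers.masking_utils now gets the zeros stub.
    pass
finally:
    # Same restoration as performed in __exit__.
    masking_utils.find_packed_sequence_indices = original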
@@ -431,6 +441,7 @@ def __exit__(self, exc_type, exc_value, traceback):
         if is_transformers_version(">=", "4.53"):
             ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
             ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)
+            transformers.masking_utils.find_packed_sequence_indices = self.original_find_packed_sequence_indices

     def __call__(self, *args, **kwargs):
         if getattr(self._model, self.orig_forward_name) is self.orig_forward:
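Putting it together, the patcher is meant to wrap the export so that all of the replacements above are active only while tracing. The usage below is a hypothetical sketch: the constructor arguments, dummy inputs, and export call do not appear in this diff and are assumptions.

# Hypothetical usage sketch; onnx_config, model and dummy_inputs are assumed to exist.
patcher = ModelPatcher(onnx_config, model)
with patcher:
    # Inside the context: unfold/norm/repeat_interleave/__len__ are patched,
    # the vmap-free sdpa/eager mask builders are registered, and
    # find_packed_sequence_indices returns zeros (no packed sequences).
    torch.onnx.export(model, (dummy_inputs,), "model.onnx")
# On exit, the original ops, mask functions and masking helper are restored.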