@@ -408,88 +408,175 @@ def infidelity(
        >>> # Computes infidelity score for saliency maps
        >>> infid = infidelity(net, perturb_fn, input, attribution)
    """
+    # perform argument formatting
+    inputs = _format_tensor_into_tuples(inputs)  # type: ignore
+    if baselines is not None:
+        baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
+    additional_forward_args = _format_additional_forward_args(additional_forward_args)
+    attributions = _format_tensor_into_tuples(attributions)  # type: ignore

-    def _generate_perturbations(
-        current_n_perturb_samples: int,
-    ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
-        r"""
-        The perturbations are generated for each example
-        `current_n_perturb_samples` times.
+    # Make sure that inputs and corresponding attributions have matching sizes.
+    assert len(inputs) == len(attributions), (
+        """The number of tensors in the inputs and
+        attributions must match. Found number of tensors in the inputs is: {} and in the
+        attributions: {}"""
+    ).format(len(inputs), len(attributions))
+    for inp, attr in zip(inputs, attributions):
+        assert inp.shape == attr.shape, (
+            """Inputs and attributions must have
+            matching shapes. One of the input tensor's shape is {} and the
+            attribution tensor's shape is: {}"""
+            # pyre-fixme[16]: Module `attr` has no attribute `shape`.
+        ).format(inp.shape, attr.shape)

-        For performance reasons we are not calling `perturb_func` on each example but
-        on a batch that contains `current_n_perturb_samples`
-        repeated instances per example.
-        """
+    bsz = inputs[0].size(0)

-        # pyre-fixme[3]: Return type must be annotated.
-        def call_perturb_func():
-            r""" """
-            baselines_pert = None
-            inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
-            if len(inputs_expanded) == 1:
-                inputs_pert = inputs_expanded[0]
-                if baselines_expanded is not None:
-                    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
-                    # parameter.
-                    baselines_pert = cast(Tuple, baselines_expanded)[0]
-            else:
-                inputs_pert = inputs_expanded
-                baselines_pert = baselines_expanded
-            return (
-                perturb_func(inputs_pert, baselines_pert)
-                if baselines_pert is not None
-                else perturb_func(inputs_pert)
-            )
+    _next_infidelity_tensors = _make_next_infidelity_tensors_func(
+        forward_func,
+        bsz,
+        perturb_func,
+        inputs,
+        baselines,
+        attributions,
+        additional_forward_args,
+        target,
+        normalize,
+    )
+
+    with torch.no_grad():
+        # if not normalize, directly return aggregated MSE ((a-b)^2,)
+        # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
+        agg_tensors = _divide_and_aggregate_metrics(
+            cast(Tuple[Tensor, ...], inputs),
+            n_perturb_samples,
+            _next_infidelity_tensors,
+            agg_func=_sum_infidelity_tensors,
+            max_examples_per_batch=max_examples_per_batch,
+        )
+
+    if normalize:
+        beta_num = agg_tensors[1]
+        beta_denorm = agg_tensors[0]

-        inputs_expanded = tuple(
-            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
-            for input in inputs
+        beta = safe_div(beta_num, beta_denorm)
+
+        infidelity_values = (
+            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+            # `int`.
+            beta**2 * agg_tensors[0]
+            - 2 * beta * agg_tensors[1]
+            + agg_tensors[2]
        )
+    else:
+        infidelity_values = agg_tensors[0]

-        baselines_expanded = baselines
-        if baselines is not None:
-            baselines_expanded = tuple(
-                (
-                    baseline.repeat_interleave(current_n_perturb_samples, dim=0)
-                    if isinstance(baseline, torch.Tensor)
-                    and baseline.shape[0] == input.shape[0]
-                    and baseline.shape[0] > 1
-                    else baseline
-                )
+    infidelity_values /= n_perturb_samples
+
+    return infidelity_values
+
+
+def _generate_perturbations(
+    current_n_perturb_samples: int,
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    perturb_func: Callable,
+    inputs: TensorOrTupleOfTensorsGeneric,
+    baselines: BaselineType,
+) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
+    r"""
+    The perturbations are generated for each example
+    `current_n_perturb_samples` times.
+
+    For performance reasons we are not calling `perturb_func` on each example but
+    on a batch that contains `current_n_perturb_samples`
+    repeated instances per example.
+    """
+
+    # pyre-fixme[3]: Return type must be annotated.
+    def call_perturb_func():
+        r""" """
+        baselines_pert = None
+        inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
+        if len(inputs_expanded) == 1:
+            inputs_pert = inputs_expanded[0]
+            if baselines_expanded is not None:
                # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
                # parameter.
-                for input, baseline in zip(inputs, cast(Tuple, baselines))
+                baselines_pert = cast(Tuple, baselines_expanded)[0]
+        else:
+            inputs_pert = inputs_expanded
+            baselines_pert = baselines_expanded
+        return (
+            perturb_func(inputs_pert, baselines_pert)
+            if baselines_pert is not None
+            else perturb_func(inputs_pert)
+        )
+
+    inputs_expanded = tuple(
+        torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
+        for input in inputs
+    )
+
+    baselines_expanded = baselines
+    if baselines is not None:
+        baselines_expanded = tuple(
+            (
+                baseline.repeat_interleave(current_n_perturb_samples, dim=0)
+                if isinstance(baseline, torch.Tensor)
+                and baseline.shape[0] == input.shape[0]
+                and baseline.shape[0] > 1
+                else baseline
            )
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
+            # parameter.
+            for input, baseline in zip(inputs, cast(Tuple, baselines))
+        )

-        return call_perturb_func()
-
-    def _validate_inputs_and_perturbations(
-        inputs: Tuple[Tensor, ...],
-        inputs_perturbed: Tuple[Tensor, ...],
-        perturbations: Tuple[Tensor, ...],
-    ) -> None:
-        # asserts the sizes of the perturbations and inputs
-        assert len(perturbations) == len(inputs), (
-            """The number of perturbed
-            inputs and corresponding perturbations must have the same number of
-            elements. Found number of inputs is: {} and perturbations:
-            {}"""
-        ).format(len(perturbations), len(inputs))
-
-        # asserts the shapes of the perturbations and perturbed inputs
-        for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
-            assert perturb[0].shape == input_perturbed[0].shape, (
-                """Perturbed input
-                and corresponding perturbation must have the same shape and
-                dimensionality. Found perturbation shape is: {} and the input shape
-                is: {}"""
-            ).format(perturb[0].shape, input_perturbed[0].shape)
+    return call_perturb_func()
+
+
+def _validate_inputs_and_perturbations(
+    inputs: Tuple[Tensor, ...],
+    inputs_perturbed: Tuple[Tensor, ...],
+    perturbations: Tuple[Tensor, ...],
+) -> None:
+    # asserts the sizes of the perturbations and inputs
+    assert len(perturbations) == len(inputs), (
+        """The number of perturbed
+        inputs and corresponding perturbations must have the same number of
+        elements. Found number of inputs is: {} and perturbations:
+        {}"""
+    ).format(len(perturbations), len(inputs))
+
+    # asserts the shapes of the perturbations and perturbed inputs
+    for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
+        assert perturb[0].shape == input_perturbed[0].shape, (
+            """Perturbed input
+            and corresponding perturbation must have the same shape and
+            dimensionality. Found perturbation shape is: {} and the input shape
+            is: {}"""
+        ).format(perturb[0].shape, input_perturbed[0].shape)
+
+
+def _make_next_infidelity_tensors_func(
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    forward_func: Callable,
+    bsz: int,
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    perturb_func: Callable,
+    inputs: TensorOrTupleOfTensorsGeneric,
+    baselines: BaselineType,
+    attributions: TensorOrTupleOfTensorsGeneric,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    additional_forward_args: Any = None,
+    target: TargetType = None,
+    normalize: bool = False,
+) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]:

    def _next_infidelity_tensors(
        current_n_perturb_samples: int,
    ) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]:
        perturbations, inputs_perturbed = _generate_perturbations(
-            current_n_perturb_samples
+            current_n_perturb_samples, perturb_func, inputs, baselines
        )

        perturbations = _format_tensor_into_tuples(perturbations)
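
Note: the batching trick that `_generate_perturbations` preserves (one `perturb_func` call on a batch holding `current_n_perturb_samples` repeated rows per example, instead of one call per sample) can be seen in isolation below. This is a minimal sketch, not part of the patch, assuming a single input tensor and a hypothetical Gaussian `perturb_fn`:

    import torch

    def perturb_fn(inputs: torch.Tensor):
        # Hypothetical perturbation: small Gaussian noise. Returns the noise
        # and the perturbed inputs, matching the expected return convention.
        noise = torch.randn_like(inputs) * 0.01
        return noise, inputs - noise

    bsz, n_samples = 4, 10
    inputs = torch.rand(bsz, 3)

    # Rows 0..n_samples-1 are copies of example 0, the next n_samples rows
    # are copies of example 1, and so on; one batched call replaces
    # bsz * n_samples separate perturb_fn calls.
    expanded = torch.repeat_interleave(inputs, n_samples, dim=0)  # shape (40, 3)
    perturbations, inputs_perturbed = perturb_fn(expanded)

The refactor passes the former closure variables (`perturb_func`, `inputs`, `baselines`) to `_generate_perturbations` explicitly; only `_next_infidelity_tensors` stays a closure, since `_divide_and_aggregate_metrics` expects a callable taking just the current number of perturbation samples.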
@@ -564,60 +651,10 @@ def _next_infidelity_tensors(
        # returns (a-b)^2 if no need to normalize
        return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)

-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _sum_infidelity_tensors(agg_tensors, tensors):
-        return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))
+    return _next_infidelity_tensors

-    # perform argument formatting
-    inputs = _format_tensor_into_tuples(inputs)  # type: ignore
-    if baselines is not None:
-        baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
-    additional_forward_args = _format_additional_forward_args(additional_forward_args)
-    attributions = _format_tensor_into_tuples(attributions)  # type: ignore

-    # Make sure that inputs and corresponding attributions have matching sizes.
-    assert len(inputs) == len(attributions), (
-        """The number of tensors in the inputs and
-        attributions must match. Found number of tensors in the inputs is: {} and in the
-        attributions: {}"""
-    ).format(len(inputs), len(attributions))
-    for inp, attr in zip(inputs, attributions):
-        assert inp.shape == attr.shape, (
-            """Inputs and attributions must have
-            matching shapes. One of the input tensor's shape is {} and the
-            attribution tensor's shape is: {}"""
-            # pyre-fixme[16]: Module `attr` has no attribute `shape`.
-        ).format(inp.shape, attr.shape)
-
-    bsz = inputs[0].size(0)
-    with torch.no_grad():
-        # if not normalize, directly return aggregated MSE ((a-b)^2,)
-        # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
-        agg_tensors = _divide_and_aggregate_metrics(
-            cast(Tuple[Tensor, ...], inputs),
-            n_perturb_samples,
-            _next_infidelity_tensors,
-            agg_func=_sum_infidelity_tensors,
-            max_examples_per_batch=max_examples_per_batch,
-        )
-
-    if normalize:
-        beta_num = agg_tensors[1]
-        beta_denorm = agg_tensors[0]
-
-        beta = safe_div(beta_num, beta_denorm)
-
-        infidelity_values = (
-            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
-            # `int`.
-            beta**2 * agg_tensors[0]
-            - 2 * beta * agg_tensors[1]
-            + agg_tensors[2]
-        )
-    else:
-        infidelity_values = agg_tensors[0]
-
-    infidelity_values /= n_perturb_samples
-
-    return infidelity_values
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def _sum_infidelity_tensors(agg_tensors, tensors):
+    return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))
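
Note: the `normalize=True` branch above computes a closed-form optimal scaling of the attributions. Writing a for the attribution-perturbation products and b for the perturbed forward-output differences, `agg_tensors` accumulates (sum a^2, sum a*b, sum b^2), and beta = sum(a*b) / sum(a^2) is the scale that minimizes sum((beta * a - b)^2); expanding that square gives the beta**2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2] expression. A self-contained sketch checking the algebra on synthetic tensors (not part of the patch):

    import torch

    # Synthetic stand-ins: a plays the attribution-perturbation dot products,
    # b the forward-output differences, with shape (batch, n_perturb_samples).
    a = torch.randn(4, 1000)
    b = a + torch.randn(4, 1000)

    # Polynomial-expansion tensors kept by the normalize=True path.
    a2, ab, b2 = (a * a).sum(-1), (a * b).sum(-1), (b * b).sum(-1)

    # Per-example optimal scale (computed with safe_div in the patch).
    beta = ab / a2

    expanded = beta**2 * a2 - 2 * beta * ab + b2
    direct = ((beta.unsqueeze(-1) * a - b) ** 2).sum(-1)
    assert torch.allclose(expanded, direct, rtol=1e-3)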